diff --git a/.clang-format b/.clang-format index c5ab0983b7530..5c0f059e15f3f 100644 --- a/.clang-format +++ b/.clang-format @@ -1,6 +1,5 @@ -BasedOnStyle: Chromium -ColumnLimit: 80 +BasedOnStyle: Google +ColumnLimit: 90 DerivePointerAlignment: false IndentCaseLabels: false PointerAlignment: Right -SpaceAfterCStyleCast: true diff --git a/.gitignore b/.gitignore index abd60923e6314..91189b6f9c41a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,26 +4,17 @@ /python/build /python/dist /python/flatbuffers-1.7.1/ -/src/common/thirdparty/redis -/src/thirdparty/arrow /flatbuffers-1.7.1/ -/src/thirdparty/boost/ -/src/thirdparty/boost_1_65_1/ -/src/thirdparty/boost_1_60_0/ -/src/thirdparty/catapult/ -/src/thirdparty/flatbuffers/ -/src/thirdparty/parquet-cpp /thirdparty/pkg/ # Files generated by flatc should be ignored -/src/common/format/*.py -/src/common/format/*_generated.h -/src/plasma/format/ -/src/local_scheduler/format/*_generated.h /src/ray/gcs/format/*_generated.h /src/ray/object_manager/format/*_generated.h /src/ray/raylet/format/*_generated.h +# Modin source files +/python/ray/modin + # Redis temporary files *dump.rdb @@ -54,9 +45,6 @@ python/.eggs *.dylib *.dll -# Cython-generated files -*.c - # Incremental linking files *.ilk diff --git a/.travis.yml b/.travis.yml index 47bef360e51e5..795dff67b6108 100644 --- a/.travis.yml +++ b/.travis.yml @@ -53,8 +53,8 @@ matrix: - sphinx-build -W -b html -d _build/doctrees source _build/html - cd .. # Run Python linting, ignore dict vs {} (C408), others are defaults - - flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504 - - .travis/yapf.sh --all + - flake8 --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 + - .travis/format.sh --all - os: linux dist: trusty @@ -69,16 +69,9 @@ matrix: script: - cd build - - bash ../src/common/test/run_valgrind.sh - - bash ../src/plasma/test/run_valgrind.sh - - bash ../src/local_scheduler/test/run_valgrind.sh - bash ../src/ray/test/run_object_manager_valgrind.sh - cd .. - - python ./python/ray/plasma/test/test.py valgrind - - python ./python/ray/local_scheduler/test/test.py valgrind - # - python ./python/ray/global_scheduler/test/test.py valgrind - # Build Linux wheels. - os: linux dist: trusty @@ -107,63 +100,6 @@ matrix: env: - PYTHON=3.5 - RAY_USE_NEW_GCS=on - - RAY_USE_XRAY=1 - - - os: linux - dist: trusty - env: PYTHON=3.5 RAY_USE_XRAY=1 - install: - - ./.travis/install-dependencies.sh - - export PATH="$HOME/miniconda/bin:$PATH" - - ./.travis/install-ray.sh - - ./.travis/install-cython-examples.sh - script: - - export PATH="$HOME/miniconda/bin:$PATH" - # The following is needed so cloudpickle can find some of the - # class definitions: The main module of tests that are run - # with pytest have the same name as the test file -- and this - # module is only found if the test directory is in the PYTHONPATH. 
- - export PYTHONPATH="$PYTHONPATH:./test/" - - - python -m pytest -v python/ray/common/test/test.py - - python -m pytest -v python/ray/common/redis_module/runtest.py - - python -m pytest -v python/ray/plasma/test/test.py - # - python -m pytest -v python/ray/local_scheduler/test/test.py - # - python -m pytest -v python/ray/global_scheduler/test/test.py - - - python -m pytest -v python/ray/test/test_global_state.py - - python -m pytest -v python/ray/test/test_queue.py - - python -m pytest -v test/xray_test.py - - - python -m pytest -v test/runtest.py - - python -m pytest -v test/array_test.py - - python -m pytest -v test/actor_test.py - - python -m pytest -v test/autoscaler_test.py - - python -m pytest -v test/tensorflow_test.py - - python -m pytest -v test/failure_test.py - - python -m pytest -v test/microbenchmarks.py - - python -m pytest -v test/stress_tests.py - - pytest test/component_failures_test.py - - python test/multi_node_test.py - - python -m pytest -v test/recursion_test.py - - pytest test/monitor_test.py - - python -m pytest -v test/cython_test.py - - python -m pytest -v test/credis_test.py - - # ray tune tests - - python python/ray/tune/test/dependency_test.py - - python -m pytest -v python/ray/tune/test/trial_runner_test.py - - python -m pytest -v python/ray/tune/test/trial_scheduler_test.py - - python -m pytest -v python/ray/tune/test/experiment_test.py - - python -m pytest -v python/ray/tune/test/tune_server_test.py - - python -m pytest -v python/ray/tune/test/ray_trial_executor_test.py - - python -m pytest -v python/ray/tune/test/automl_searcher_test.py - - # ray rllib tests - - python -m pytest -v python/ray/rllib/test/test_catalog.py - - python -m pytest -v python/ray/rllib/test/test_filters.py - - python -m pytest -v python/ray/rllib/test/test_optimizers.py - - python -m pytest -v python/ray/rllib/test/test_evaluators.py install: @@ -181,12 +117,10 @@ install: - ./src/ray/raylet/lineage_cache_test - ./src/ray/raylet/task_dependency_manager_test - ./src/ray/raylet/reconstruction_policy_test + - ./src/ray/raylet/client_connection_test - ./src/ray/util/logging_test --gtest_filter=PrintLogTest* - ./src/ray/util/signal_test - - bash ../src/common/test/run_tests.sh - - bash ../src/plasma/test/run_tests.sh - - bash ../src/local_scheduler/test/run_tests.sh - cd .. script: @@ -197,14 +131,27 @@ script: # module is only found if the test directory is in the PYTHONPATH. - export PYTHONPATH="$PYTHONPATH:./test/" - - python -m pytest -v python/ray/common/test/test.py - - python -m pytest -v python/ray/common/redis_module/runtest.py - - python -m pytest -v python/ray/plasma/test/test.py - - python -m pytest -v python/ray/local_scheduler/test/test.py - - python -m pytest -v python/ray/global_scheduler/test/test.py + # ray tune tests + - python python/ray/tune/test/dependency_test.py + - python -m pytest -v python/ray/tune/test/trial_runner_test.py + - python -m pytest -v python/ray/tune/test/trial_scheduler_test.py + - python -m pytest -v python/ray/tune/test/experiment_test.py + - python -m pytest -v python/ray/tune/test/tune_server_test.py + - python -m pytest -v python/ray/tune/test/ray_trial_executor_test.py + - python -m pytest -v python/ray/tune/test/automl_searcher_test.py + + # ray rllib tests + - python -m pytest -v python/ray/rllib/test/test_catalog.py + - python -m pytest -v python/ray/rllib/test/test_filters.py + - python -m pytest -v python/ray/rllib/test/test_optimizers.py + - python -m pytest -v python/ray/rllib/test/test_evaluators.py + + # Python3.5+ only. 
Otherwise we will get `SyntaxError` regardless of how we set the tester. + - python -c 'import sys;exit(sys.version_info>=(3,5))' || python -m pytest -v python/ray/experimental/test/async_test.py - python -m pytest -v python/ray/test/test_global_state.py - python -m pytest -v python/ray/test/test_queue.py + - python -m pytest -v python/ray/test/test_ray_init.py - python -m pytest -v test/xray_test.py - python -m pytest -v test/runtest.py @@ -216,26 +163,19 @@ script: - python -m pytest -v test/microbenchmarks.py - python -m pytest -v test/stress_tests.py - python -m pytest -v test/component_failures_test.py - - python test/multi_node_test.py + - python -m pytest -v test/multi_node_test.py + - python -m pytest -v test/multi_node_test_2.py - python -m pytest -v test/recursion_test.py - python -m pytest -v test/monitor_test.py - python -m pytest -v test/cython_test.py - python -m pytest -v test/credis_test.py + - python -m pytest -v test/node_manager_test.py - # ray tune tests - - python python/ray/tune/test/dependency_test.py - - python -m pytest -v python/ray/tune/test/trial_runner_test.py - - python -m pytest -v python/ray/tune/test/trial_scheduler_test.py - - python -m pytest -v python/ray/tune/test/experiment_test.py - - python -m pytest -v python/ray/tune/test/tune_server_test.py - - python -m pytest -v python/ray/tune/test/ray_trial_executor_test.py - - python -m pytest -v python/ray/tune/test/automl_searcher_test.py + # ray temp file tests + - python -m pytest -v test/tempfile_test.py - # ray rllib tests - - python -m pytest -v python/ray/rllib/test/test_catalog.py - - python -m pytest -v python/ray/rllib/test/test_filters.py - - python -m pytest -v python/ray/rllib/test/test_optimizers.py - - python -m pytest -v python/ray/rllib/test/test_evaluators.py + # modin test files + - python python/ray/test/test_modin.py deploy: - provider: s3 diff --git a/.travis/yapf.sh b/.travis/format.sh similarity index 74% rename from .travis/yapf.sh rename to .travis/format.sh index d90aec89531d2..9313e641065a8 100755 --- a/.travis/yapf.sh +++ b/.travis/format.sh @@ -1,4 +1,6 @@ #!/usr/bin/env bash +# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# You are encouraged to run this locally before pushing changes for review. # Cause the script to exit if a single command fails set -eo pipefail @@ -28,7 +30,6 @@ YAPF_EXCLUDES=( '--exclude' 'python/build/*' '--exclude' 'python/ray/pyarrow_files/*' '--exclude' 'python/ray/core/src/ray/gcs/*' - '--exclude' 'python/ray/common/thirdparty/*' ) # Format specified files @@ -50,6 +51,18 @@ format_changed() { if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \ yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" + if which flake8 >/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \ + flake8 --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ \ + --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 + fi + fi + + if which clang-format >/dev/null; then + if ! 
git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.cc' '*.h' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.cc' '*.h' | xargs -P 5 \ + clang-format -i + fi fi } diff --git a/.travis/install-dependencies.sh b/.travis/install-dependencies.sh index 1c6c3a342a616..5bae4ba87f8db 100755 --- a/.travis/install-dependencies.sh +++ b/.travis/install-dependencies.sh @@ -24,8 +24,8 @@ if [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "linux" ]]; then wget https://repo.continuum.io/miniconda/Miniconda2-4.5.4-Linux-x86_64.sh -O miniconda.sh -nv bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.22 requests \ - feather-format lxml openpyxl xlrd + pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ + feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then sudo apt-get update sudo apt-get install -y cmake pkg-config python-dev python-numpy build-essential autoconf curl libtool unzip @@ -33,8 +33,8 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then wget https://repo.continuum.io/miniconda/Miniconda3-4.5.4-Linux-x86_64.sh -O miniconda.sh -nv bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.22 requests \ - feather-format lxml openpyxl xlrd + pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ + feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then # check that brew is installed which -s brew @@ -50,8 +50,8 @@ elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then wget https://repo.continuum.io/miniconda/Miniconda2-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.22 requests \ - feather-format lxml openpyxl xlrd + pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ + feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then # check that brew is installed which -s brew @@ -67,8 +67,8 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then wget https://repo.continuum.io/miniconda/Miniconda3-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.22 requests \ - feather-format lxml openpyxl xlrd + pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ + feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout elif [[ "$LINT" == "1" ]]; then sudo apt-get update sudo apt-get install -y cmake build-essential autoconf curl libtool unzip diff --git a/.travis/test-wheels.sh b/.travis/test-wheels.sh index 1b77209c3ddc7..f7870ea52d496 100755 --- a/.travis/test-wheels.sh +++ b/.travis/test-wheels.sh @@ -56,10 +56,10 @@ if [[ "$platform" == "linux" ]]; then # Check that the other wheels are present. 
NUMBER_OF_WHEELS=$(ls -1q $ROOT_DIR/../.whl/*.whl | wc -l) - if [[ "$NUMBER_OF_WHEELS" != "4" ]]; then + if [[ "$NUMBER_OF_WHEELS" != "5" ]]; then echo "Wrong number of wheels found." ls -l $ROOT_DIR/../.whl/ - exit 1 + exit 2 fi elif [[ "$platform" == "macosx" ]]; then @@ -67,12 +67,14 @@ elif [[ "$platform" == "macosx" ]]; then PY_MMS=("2.7" "3.4" "3.5" - "3.6") + "3.6" + "3.7") # This array is just used to find the right wheel. PY_WHEEL_VERSIONS=("27" "34" "35" - "36") + "36" + "37") for ((i=0; i<${#PY_MMS[@]}; ++i)); do PY_MM=${PY_MMS[i]} @@ -92,5 +94,5 @@ elif [[ "$platform" == "macosx" ]]; then done else echo "Unrecognized environment." - exit 1 + exit 3 fi diff --git a/CMakeLists.txt b/CMakeLists.txt index d02e88a5c4203..a6734e62ce144 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,18 +82,15 @@ include_directories(SYSTEM ${PLASMA_INCLUDE_DIR}) include_directories("${CMAKE_CURRENT_LIST_DIR}/src/") add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/ray/) -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/common/) -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/plasma/) -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/local_scheduler/) -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/global_scheduler/) # final target copy_ray add_custom_target(copy_ray ALL) # copy plasma_store_server add_custom_command(TARGET copy_ray POST_BUILD + COMMAND mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/src/plasma COMMAND ${CMAKE_COMMAND} -E - copy ${ARROW_HOME}/bin/plasma_store_server ${CMAKE_CURRENT_BINARY_DIR}/src/plasma) + copy ${ARROW_HOME}/bin/plasma_store_server ${CMAKE_CURRENT_BINARY_DIR}/src/plasma/) if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") # add pyarrow as the dependency @@ -102,12 +99,9 @@ if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") # NOTE: The lists below must be kept in sync with ray/python/setup.py. set(ray_file_list - "src/common/thirdparty/redis/src/redis-server" - "src/common/redis_module/libray_redis_module.so" - "src/plasma/plasma_manager" - "src/local_scheduler/local_scheduler" - "src/local_scheduler/liblocal_scheduler_library_python.so" - "src/global_scheduler/global_scheduler" + "src/ray/thirdparty/redis/src/redis-server" + "src/ray/gcs/redis_module/libray_redis_module.so" + "src/ray/raylet/liblocal_scheduler_library_python.so" "src/ray/raylet/raylet_monitor" "src/ray/raylet/raylet") @@ -117,7 +111,10 @@ if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") list(APPEND ray_file_list "src/credis/redis/src/redis-server") endif() - if (DEFINED ENV{INCLUDE_UI} AND "$ENV{INCLUDE_UI}" STREQUAL "1") + # The goal of the if statement below is to require the catapult files to be + # present INCLUDE_UI=1 is set and to include the UI files if they are present. + # This should match the logic in build_ui.sh. + if (EXISTS "${CMAKE_BINARY_DIR}/src/catapult_files/index.html" OR "$ENV{INCLUDE_UI}" STREQUAL "1") list(APPEND ray_file_list "src/catapult_files/index.html") list(APPEND ray_file_list "src/catapult_files/trace_viewer_full.html") endif() @@ -154,5 +151,6 @@ if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") # copy libplasma_java files add_custom_command(TARGET copy_ray POST_BUILD - COMMAND bash -c "cp ${ARROW_LIBRARY_DIR}/libplasma_java.* ${CMAKE_CURRENT_BINARY_DIR}/src/plasma") + COMMAND bash -c "mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/src/plasma" + COMMAND bash -c "cp ${ARROW_LIBRARY_DIR}/libplasma_java.* ${CMAKE_CURRENT_BINARY_DIR}/src/plasma/") endif() diff --git a/README.rst b/README.rst index 356ef60ebf6f9..5fd892f95f037 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,6 @@ -Ray -=== +.. 
raw:: html + + .. image:: https://travis-ci.com/ray-project/ray.svg?branch=master :target: https://travis-ci.com/ray-project/ray @@ -7,9 +8,12 @@ Ray .. image:: https://readthedocs.org/projects/ray/badge/?version=latest :target: http://ray.readthedocs.io/en/latest/?badge=latest +.. image:: https://img.shields.io/badge/pypi-0.6.0-blue.svg + :target: https://pypi.org/project/ray/ + | -Ray is a flexible, high-performance distributed execution framework. +**Ray is a flexible, high-performance distributed execution framework.** Ray is easy to install: ``pip install ray`` @@ -37,11 +41,12 @@ Example Use Ray comes with libraries that accelerate deep learning and reinforcement learning development: -- `Ray Tune`_: Hyperparameter Optimization Framework -- `Ray RLlib`_: Scalable Reinforcement Learning +- `Tune`_: Hyperparameter Optimization Framework +- `RLlib`_: Scalable Reinforcement Learning +- `Distributed Training `__ -.. _`Ray Tune`: http://ray.readthedocs.io/en/latest/tune.html -.. _`Ray RLlib`: http://ray.readthedocs.io/en/latest/rllib.html +.. _`Tune`: http://ray.readthedocs.io/en/latest/tune.html +.. _`RLlib`: http://ray.readthedocs.io/en/latest/rllib.html Installation ------------ diff --git a/build.sh b/build.sh index 496bbdddb5750..6aa695b83a924 100755 --- a/build.sh +++ b/build.sh @@ -25,7 +25,7 @@ function usage() # Determine how many parallel jobs to use for make based on the number of cores unamestr="$(uname)" if [[ "$unamestr" == "Linux" ]]; then - PARALLEL=$(nproc) + PARALLEL=$(nproc --all) elif [[ "$unamestr" == "Darwin" ]]; then PARALLEL=$(sysctl -n hw.ncpu) else @@ -101,12 +101,16 @@ fi pushd "$BUILD_DIR" +# avoid the command failed and exits +# and cmake will check some directories to determine whether some targets built +make clean || true +rm -rf external/arrow-install + cmake -DCMAKE_BUILD_TYPE=$CBUILD_TYPE \ -DCMAKE_RAY_LANG_JAVA=$RAY_BUILD_JAVA \ -DCMAKE_RAY_LANG_PYTHON=$RAY_BUILD_PYTHON \ -DRAY_USE_NEW_GCS=$RAY_USE_NEW_GCS \ -DPYTHON_EXECUTABLE:FILEPATH=$PYTHON_EXECUTABLE $ROOT_DIR -make clean make -j${PARALLEL} popd diff --git a/cmake/Modules/ArrowExternalProject.cmake b/cmake/Modules/ArrowExternalProject.cmake index dfb25f244f9a2..3e19dfbd2672f 100644 --- a/cmake/Modules/ArrowExternalProject.cmake +++ b/cmake/Modules/ArrowExternalProject.cmake @@ -9,25 +9,21 @@ # - ARROW_INCLUDE_DIR # - ARROW_SHARED_LIB # - ARROW_STATIC_LIB +# - ARROW_LIBRARY_DIR # - PLASMA_INCLUDE_DIR # - PLASMA_STATIC_LIB # - PLASMA_SHARED_LIB set(arrow_URL https://github.com/apache/arrow.git) -# The PR for this commit is https://github.com/apache/arrow/pull/2522. We +# The PR for this commit is https://github.com/apache/arrow/pull/3093. We # include the link here to make it easier to find the right commit because # Arrow often rewrites git history and invalidates certain commits. 
-set(arrow_TAG 7104d64ff2cd6c20e29d3cf4ec5c58bc10798f66) +set(arrow_TAG 187b98ed338d4995317dae9efd19870c532192cb) set(ARROW_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/external/arrow-install) set(ARROW_HOME ${ARROW_INSTALL_PREFIX}) set(ARROW_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/arrow/src/arrow_ep) -# The following is needed because in CentOS, the lib directory is named lib64 -if(EXISTS "/etc/redhat-release" AND CMAKE_SIZEOF_VOID_P EQUAL 8) - set(LIB_SUFFIX 64) -endif() - set(ARROW_INCLUDE_DIR ${ARROW_HOME}/include) set(ARROW_LIBRARY_DIR ${ARROW_HOME}/lib${LIB_SUFFIX}) set(ARROW_SHARED_LIB ${ARROW_LIBRARY_DIR}/libarrow${CMAKE_SHARED_LIBRARY_SUFFIX}) @@ -58,7 +54,8 @@ set(ARROW_CMAKE_ARGS -DARROW_WITH_LZ4=off -DARROW_WITH_ZSTD=off -DFLATBUFFERS_HOME=${FLATBUFFERS_HOME} - -DBOOST_ROOT=${BOOST_ROOT}) + -DBOOST_ROOT=${BOOST_ROOT} + -DGLOG_HOME=${GLOG_HOME}) if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") # PyArrow needs following settings. @@ -92,19 +89,24 @@ endif() ExternalProject_Add(arrow_ep PREFIX external/arrow - DEPENDS flatbuffers boost + DEPENDS flatbuffers boost glog GIT_REPOSITORY ${arrow_URL} GIT_TAG ${arrow_TAG} + UPDATE_COMMAND "" ${ARROW_CONFIGURE} BUILD_BYPRODUCTS "${ARROW_SHARED_LIB}" "${ARROW_STATIC_LIB}") if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") - ExternalProject_Add_Step(arrow_ep arrow_ep_install_java_lib - COMMAND bash -c "cd ${ARROW_SOURCE_DIR}/java && mvn clean install -pl plasma -am -Dmaven.test.skip > /dev/null" - DEPENDEES build) + set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES "${ARROW_SOURCE_DIR}/java/target/") + + if(NOT EXISTS ${ARROW_SOURCE_DIR}/java/target/) + ExternalProject_Add_Step(arrow_ep arrow_ep_install_java_lib + COMMAND bash -c "cd ${ARROW_SOURCE_DIR}/java && mvn clean install -pl plasma -am -Dmaven.test.skip > /dev/null" + DEPENDEES build) + endif() # add install of library plasma_java, it is not configured in plasma CMakeLists.txt ExternalProject_Add_Step(arrow_ep arrow_ep_install_plasma_java - COMMAND bash -c "cp ${CMAKE_CURRENT_BINARY_DIR}/external/arrow/src/arrow_ep-build/release/libplasma_java.* ${ARROW_LIBRARY_DIR}/" + COMMAND bash -c "cp -rf ${CMAKE_CURRENT_BINARY_DIR}/external/arrow/src/arrow_ep-build/release/libplasma_java.* ${ARROW_LIBRARY_DIR}/" DEPENDEES install) endif () diff --git a/cmake/Modules/BoostExternalProject.cmake b/cmake/Modules/BoostExternalProject.cmake index bab016a02b7a3..1fbbb0c0b58ef 100644 --- a/cmake/Modules/BoostExternalProject.cmake +++ b/cmake/Modules/BoostExternalProject.cmake @@ -9,9 +9,9 @@ # boost is a stable library in ray, and it supports to find # the boost pre-built in environment to speed up build process -if (DEFINED ENV{BOOST_ROOT} AND EXISTS ENV{BOOST_ROOT}) +if (DEFINED ENV{RAY_BOOST_ROOT} AND EXISTS $ENV{RAY_BOOST_ROOT}) set(Boost_USE_STATIC_LIBS ON) - set(BOOST_ROOT "$ENV{BOOST_ROOT}") + set(BOOST_ROOT "$ENV{RAY_BOOST_ROOT}") message(STATUS "Find BOOST_ROOT: ${BOOST_ROOT}") # find_package(Boost COMPONENTS system filesystem REQUIRED) set(Boost_INCLUDE_DIR ${BOOST_ROOT}/include) diff --git a/cmake/Modules/Common.cmake b/cmake/Modules/Common.cmake index cc2a5d5ff9926..7d33f13e9d450 100644 --- a/cmake/Modules/Common.cmake +++ b/cmake/Modules/Common.cmake @@ -41,6 +41,3 @@ if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") message (WARNING "NOT FIND JNI") endif() endif() - -include_directories(${CMAKE_SOURCE_DIR}/src/common) -include_directories(${CMAKE_SOURCE_DIR}/src/common/thirdparty) diff --git a/cmake/Modules/FlatBuffersExternalProject.cmake 
b/cmake/Modules/FlatBuffersExternalProject.cmake index 57c2216cecfb7..508010afced49 100644 --- a/cmake/Modules/FlatBuffersExternalProject.cmake +++ b/cmake/Modules/FlatBuffersExternalProject.cmake @@ -10,13 +10,8 @@ # - FLATBUFFERS_COMPILER # - FBS_DEPENDS, to keep compatible -# The following is needed because in CentOS, the lib directory is named lib64 -if(EXISTS "/etc/redhat-release" AND CMAKE_SIZEOF_VOID_P EQUAL 8) - set(LIB_SUFFIX 64) -endif() - -if(DEFINED ENV{FLATBUFFERS_HOME} AND EXISTS ENV{FLATBUFFERS_HOME}) - set(FLATBUFFERS_HOME "$ENV{FLATBUFFERS_HOME}") +if(DEFINED ENV{RAY_FLATBUFFERS_HOME} AND EXISTS $ENV{RAY_FLATBUFFERS_HOME}) + set(FLATBUFFERS_HOME "$ENV{RAY_FLATBUFFERS_HOME}") set(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_HOME}/include") set(FLATBUFFERS_STATIC_LIB "${FLATBUFFERS_HOME}/lib${LIB_SUFFIX}/libflatbuffers.a") set(FLATBUFFERS_COMPILER "${FLATBUFFERS_HOME}/bin/flatc") diff --git a/cmake/Modules/GlogExternalProject.cmake b/cmake/Modules/GlogExternalProject.cmake index 47f11fbdbd6ad..2900bae4d523b 100644 --- a/cmake/Modules/GlogExternalProject.cmake +++ b/cmake/Modules/GlogExternalProject.cmake @@ -6,8 +6,8 @@ # - GLOG_INCLUDE_DIR # - GLOG_STATIC_LIB -if(DEFINED ENV{GLOG_HOME} AND EXISTS ENV{GLOG_HOME}) - set(GLOG_HOME "$ENV{GLOG_HOME}") +if(DEFINED ENV{RAY_GLOG_HOME} AND EXISTS $ENV{RAY_GLOG_HOME}) + set(GLOG_HOME "$ENV{RAY_GLOG_HOME}") set(GLOG_INCLUDE_DIR "${GLOG_HOME}/include") set(GLOG_STATIC_LIB "${GLOG_HOME}/lib/libglog.a") @@ -23,7 +23,7 @@ else() endif() set(GLOG_URL "https://github.com/google/glog/archive/v${GLOG_VERSION}.tar.gz") - set(GLOG_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/external/glog/src/glog_ep") + set(GLOG_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/external/glog-install") set(GLOG_HOME "${GLOG_PREFIX}") set(GLOG_INCLUDE_DIR "${GLOG_PREFIX}/include") set(GLOG_STATIC_LIB "${GLOG_PREFIX}/lib/libglog.a") diff --git a/cmake/Modules/GtestExternalProject.cmake b/cmake/Modules/GtestExternalProject.cmake index 5570066c60fbb..66e5a76f1d87e 100644 --- a/cmake/Modules/GtestExternalProject.cmake +++ b/cmake/Modules/GtestExternalProject.cmake @@ -7,18 +7,31 @@ # - GTEST_MAIN_STATIC_LIB # - GMOCK_MAIN_STATIC_LIB -if(DEFINED ENV{GTEST_HOME} AND EXISTS ENV{GTEST_HOME}) - set(GTEST_HOME "$ENV{GTEST_HOME}") - set(GTEST_INCLUDE_DIR "${GTEST_HOME}/include") - set(GTEST_STATIC_LIB - "${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}") - set(GTEST_MAIN_STATIC_LIB - "${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}") - set(GMOCK_MAIN_STATIC_LIB - "${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock_main${CMAKE_STATIC_LIBRARY_SUFFIX}") +set(GTEST_FOUND FALSE) + +if(DEFINED ENV{RAY_GTEST_HOME} AND EXISTS $ENV{RAY_GTEST_HOME}) + set(GTEST_HOME "$ENV{RAY_GTEST_HOME}") + find_path(GTEST_INCLUDE_DIR NAMES gtest/gtest.h + PATHS ${GTEST_HOME} NO_DEFAULT_PATH + PATH_SUFFIXES "include") + find_library(GTEST_LIBRARIES NAMES gtest gtest_main gmock_main + PATHS ${GTEST_HOME} NO_DEFAULT_PATH + PATH_SUFFIXES "lib") + if(GTEST_INCLUDE_DIR AND GTEST_LIBRARIES) + set(GTEST_FOUND TRUE) + set(GTEST_STATIC_LIB + "${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(GTEST_MAIN_STATIC_LIB + "${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(GMOCK_MAIN_STATIC_LIB + "${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock_main${CMAKE_STATIC_LIBRARY_SUFFIX}") + + add_custom_target(googletest_ep) + endif() + +endif() - add_custom_target(googletest_ep) 
-else() +if(NOT GTEST_FOUND) set(GTEST_VERSION "1.8.0") if(APPLE) @@ -31,7 +44,7 @@ else() endif() set(GTEST_CMAKE_CXX_FLAGS "${EP_CXX_FLAGS} ${CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}} ${GTEST_CMAKE_CXX_FLAGS}") - set(GTEST_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/external/googletest/src/googletest_ep") + set(GTEST_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/external/googletest-install") set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include") set(GTEST_STATIC_LIB "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}") diff --git a/cmake/Modules/ThirdpartyToolchain.cmake b/cmake/Modules/ThirdpartyToolchain.cmake index 0e0553483ec23..723b3cd6aa001 100644 --- a/cmake/Modules/ThirdpartyToolchain.cmake +++ b/cmake/Modules/ThirdpartyToolchain.cmake @@ -4,6 +4,11 @@ # we have to turn it on for dependencies too set(EP_CXX_FLAGS "${EP_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") +# The following is needed because in CentOS, the lib directory is named lib64 +if(EXISTS "/etc/redhat-release" AND CMAKE_SIZEOF_VOID_P EQUAL 8) + set(LIB_SUFFIX 64) +endif() + if(RAY_BUILD_TESTS OR RAY_BUILD_BENCHMARKS) add_custom_target(unittest ctest -L unittest) @@ -25,18 +30,16 @@ if(RAY_BUILD_TESTS OR RAY_BUILD_BENCHMARKS) add_dependencies(gmock_main googletest_ep) endif() -if(RAY_USE_GLOG) - include(GlogExternalProject) - message(STATUS "Glog home: ${GLOG_HOME}") - message(STATUS "Glog include dir: ${GLOG_INCLUDE_DIR}") - message(STATUS "Glog static lib: ${GLOG_STATIC_LIB}") +include(GlogExternalProject) +message(STATUS "Glog home: ${GLOG_HOME}") +message(STATUS "Glog include dir: ${GLOG_INCLUDE_DIR}") +message(STATUS "Glog static lib: ${GLOG_STATIC_LIB}") - include_directories(${GLOG_INCLUDE_DIR}) - ADD_THIRDPARTY_LIB(glog - STATIC_LIB ${GLOG_STATIC_LIB}) +include_directories(${GLOG_INCLUDE_DIR}) +ADD_THIRDPARTY_LIB(glog + STATIC_LIB ${GLOG_STATIC_LIB}) - add_dependencies(glog glog_ep) -endif() +add_dependencies(glog glog_ep) # boost include(BoostExternalProject) @@ -95,19 +98,6 @@ ADD_THIRDPARTY_LIB(plasma STATIC_LIB ${PLASMA_STATIC_LIB}) add_dependencies(plasma plasma_ep) if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - # pyarrow - find_package(PythonInterp REQUIRED) - message(STATUS "PYTHON_EXECUTABLE for pyarrow: ${PYTHON_EXECUTABLE}") - - set(pyarrow_ENV - "PKG_CONFIG_PATH=${ARROW_LIBRARY_DIR}/pkgconfig" - "PYARROW_WITH_PLASMA=1" - "PYARROW_WITH_TENSORFLOW=1" - "PYARROW_BUNDLE_ARROW_CPP=1" - "PARQUET_HOME=${PARQUET_HOME}" - "PYARROW_WITH_PARQUET=1" - ) - # clean the arrow_ep/python/build/lib.xxxxx directory, # or when you build with another python version, it creates multiple lib.xxxx directories set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES "${ARROW_SOURCE_DIR}/python/build/") @@ -115,13 +105,40 @@ if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") # here we use externalProject to process pyarrow building # add_custom_command would have problem with setup.py - ExternalProject_Add(pyarrow_ext - PREFIX external/pyarrow - DEPENDS arrow_ep - DOWNLOAD_COMMAND "" - BUILD_IN_SOURCE 1 - CONFIGURE_COMMAND cd ${ARROW_SOURCE_DIR}/python && ${CMAKE_COMMAND} -E env ${pyarrow_ENV} ${PYTHON_EXECUTABLE} setup.py build - BUILD_COMMAND cd ${ARROW_SOURCE_DIR}/python && ${CMAKE_COMMAND} -E env ${pyarrow_ENV} ${PYTHON_EXECUTABLE} setup.py build_ext - INSTALL_COMMAND bash -c "cp -rf \$(find ${ARROW_SOURCE_DIR}/python/build/ -maxdepth 1 -type d -print | grep -m1 'lib')/pyarrow ${CMAKE_SOURCE_DIR}/python/ray/pyarrow_files/") + if(EXISTS ${ARROW_SOURCE_DIR}/python/build/) + # if we did not run `make clean`, 
skip the rebuild of pyarrow + add_custom_target(pyarrow_ext) + else() + # pyarrow + find_package(PythonInterp REQUIRED) + message(STATUS "PYTHON_EXECUTABLE for pyarrow: ${PYTHON_EXECUTABLE}") + + # PYARROW_PARALLEL= , so it will add -j to pyarrow build + set(pyarrow_ENV + "PKG_CONFIG_PATH=${ARROW_LIBRARY_DIR}/pkgconfig" + "PYARROW_WITH_PLASMA=1" + "PYARROW_WITH_TENSORFLOW=1" + "PYARROW_BUNDLE_ARROW_CPP=1" + "PARQUET_HOME=${PARQUET_HOME}" + "PYARROW_WITH_PARQUET=1" + "PYARROW_PARALLEL=") + + if (APPLE) + # Since 10.14, the XCode toolchain only accepts libc++ as the + # standard library. This should also work on macOS starting from 10.9. + set(pyarrow_ENV ${pyarrow_ENV} "CXXFLAGS='-stdlib=libc++'") + set(pyarrow_ENV ${pyarrow_ENV} "MACOSX_DEPLOYMENT_TARGET=10.7") + endif() + + ExternalProject_Add(pyarrow_ext + PREFIX external/pyarrow + DEPENDS arrow_ep + DOWNLOAD_COMMAND "" + BUILD_IN_SOURCE 1 + CONFIGURE_COMMAND cd ${ARROW_SOURCE_DIR}/python && ${CMAKE_COMMAND} -E env ${pyarrow_ENV} ${PYTHON_EXECUTABLE} setup.py build + BUILD_COMMAND cd ${ARROW_SOURCE_DIR}/python && ${CMAKE_COMMAND} -E env ${pyarrow_ENV} ${PYTHON_EXECUTABLE} setup.py build_ext + INSTALL_COMMAND bash -c "cp -rf \$(find ${ARROW_SOURCE_DIR}/python/build/ -maxdepth 1 -type d -print | grep -m1 'lib')/pyarrow ${CMAKE_SOURCE_DIR}/python/ray/pyarrow_files/") + + endif() endif () diff --git a/doc/requirements-doc.txt b/doc/requirements-doc.txt index 5d953d3463400..f598baa081679 100644 --- a/doc/requirements-doc.txt +++ b/doc/requirements-doc.txt @@ -9,6 +9,7 @@ pyarrow pyyaml recommonmark redis +setproctitle sphinx sphinx-click sphinx_rtd_theme diff --git a/doc/source/actors.rst b/doc/source/actors.rst index c7594592f5124..0d8b3c94285b5 100644 --- a/doc/source/actors.rst +++ b/doc/source/actors.rst @@ -65,8 +65,7 @@ When ``a1.increment.remote()`` is called, the following events happens. 1. A task is created. 2. The task is assigned directly to the local scheduler responsible for the - actor by the driver's local scheduler. Thus, this scheduling procedure - bypasses the global scheduler. + actor by the driver's local scheduler. 3. An object ID is returned. We can then call ``ray.get`` on the object ID to retrieve the actual value. diff --git a/doc/source/async_api.rst b/doc/source/async_api.rst new file mode 100644 index 0000000000000..95867745f8ee6 --- /dev/null +++ b/doc/source/async_api.rst @@ -0,0 +1,87 @@ +Async API (Experimental) +======================== + +Since Python 3.5, it is possible to write concurrent code using the ``async/await`` `syntax `__. + +This document describes Ray's support for asyncio, which enables integration with popular async frameworks (e.g., aiohttp, aioredis, etc.) for high performance web and prediction serving. + +Starting Ray +------------ + +You must initialize Ray first. + +Please refer to `Starting Ray`_ for instructions. + +.. _`Starting Ray`: http://ray.readthedocs.io/en/latest/tutorial.html#starting-ray + + +Converting Ray objects into asyncio futures +------------------------------------------- + +Ray object IDs can be converted into asyncio futures with ``ray.experimental.async_api``. + +.. code-block:: python + + import asyncio + import time + import ray + from ray.experimental import async_api + + @ray.remote + def f(): + time.sleep(1) + return {'key1': ['value']} + + ray.init() + future = async_api.as_future(f.remote()) + asyncio.get_event_loop().run_until_complete(future) # {'key1': ['value']} + + +.. 
autofunction:: ray.experimental.async_api.as_future + + +Example Usage +------------- + ++----------------------------------------+-----------------------------------------------------+ +| **Basic Python** | **Distributed with Ray** | ++----------------------------------------+-----------------------------------------------------+ +| .. code-block:: python | .. code-block:: python | +| | | +| # Execute f serially. | # Execute f in parallel. | +| | | +| | | +| def f(): | @ray.remote | +| time.sleep(1) | def f(): | +| return 1 | time.sleep(1) | +| | return 1 | +| | | +| | ray.init() | +| results = [f() for i in range(4)] | results = ray.get([f.remote() for i in range(4)]) | ++----------------------------------------+-----------------------------------------------------+ +| **Async Python** | **Async Ray** | ++----------------------------------------+-----------------------------------------------------+ +| .. code-block:: python | .. code-block:: python | +| | | +| # Execute f asynchronously. | # Execute f asynchronously with Ray/asyncio. | +| | | +| | from ray.experimental import async_api | +| | | +| | @ray.remote | +| async def f(): | def f(): | +| await asyncio.sleep(1) | time.sleep(1) | +| return 1 | return 1 | +| | | +| | ray.init() | +| loop = asyncio.get_event_loop() | loop = asyncio.get_event_loop() | +| tasks = [f() for i in range(4)] | tasks = [async_api.as_future(f.remote()) | +| | for i in range(4)] | +| results = loop.run_until_complete( | results = loop.run_until_complete( | +| asyncio.gather(tasks)) | asyncio.gather(tasks)) | ++----------------------------------------+-----------------------------------------------------+ + + +Known Issues +------------ + +Async API support is experimental, and we are working to improve its performance. Please `let us know `__ any issues you encounter. diff --git a/doc/source/autoscaling.rst b/doc/source/autoscaling.rst index 54ebcc350e5e0..90c8e92f3d278 100644 --- a/doc/source/autoscaling.rst +++ b/doc/source/autoscaling.rst @@ -76,6 +76,14 @@ You can use ``ray exec`` to conveniently run commands on clusters. Note that scr # Run a command in a screen (experimental) $ ray exec cluster.yaml 'echo "hello world"' --screen +You can also use ``ray submit`` to execute Python scripts on clusters. This will ``rsync`` the designated file onto the cluster and execute it with the given arguments. + +.. code-block:: bash + + # Run a Python script in a detached tmux session + $ ray submit cluster.yaml --tmux --start --stop tune_experiment.py + + Attaching to the cluster ------------------------ @@ -136,7 +144,8 @@ The default idle timeout is 5 minutes. This is to prevent excessive node churn w Monitoring cluster status ------------------------- -You can monitor cluster usage and auto-scaling status by tailing the autoscaling logs in ``/tmp/raylogs/monitor-*``. +You can monitor cluster usage and auto-scaling status by tailing the autoscaling +logs in ``/tmp/ray/session_*/logs/monitor*``. The Ray autoscaler also reports per-node status in the form of instance tags. In your cloud provider console, you can click on a Node, go the the "Tags" pane, and add the ``ray-node-status`` tag as a column. This lets you see per-node statuses at a glance: diff --git a/doc/source/conf.py b/doc/source/conf.py index 27d0c1200d9c9..2a2b1a37c207e 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -18,44 +18,41 @@ # These lines added to enable Sphinx to work without installing Ray. 
import mock -MOCK_MODULES = ["gym", - "gym.spaces", - "scipy", - "scipy.signal", - "tensorflow", - "tensorflow.contrib", - "tensorflow.contrib.layers", - "tensorflow.contrib.slim", - "tensorflow.contrib.rnn", - "tensorflow.core", - "tensorflow.core.util", - "tensorflow.python", - "tensorflow.python.client", - "tensorflow.python.util", - "ray.local_scheduler", - "ray.plasma", - "ray.core", - "ray.core.generated", - "ray.core.generated.DriverTableMessage", - "ray.core.generated.LocalSchedulerInfoMessage", - "ray.core.generated.ResultTableReply", - "ray.core.generated.SubscribeToDBClientTableReply", - "ray.core.generated.SubscribeToNotificationsReply", - "ray.core.generated.TaskInfo", - "ray.core.generated.TaskReply", - "ray.core.generated.TaskExecutionDependencies", - "ray.core.generated.ClientTableData", - "ray.core.generated.GcsTableEntry", - "ray.core.generated.HeartbeatTableData", - "ray.core.generated.DriverTableData", - "ray.core.generated.ErrorTableData", - "ray.core.generated.ProfileTableData", - "ray.core.generated.ObjectTableData", - "ray.core.generated.ray.protocol.Task", - "ray.core.generated.TablePrefix", - "ray.core.generated.TablePubsub",] +MOCK_MODULES = [ + "gym", + "gym.spaces", + "scipy", + "scipy.signal", + "tensorflow", + "tensorflow.contrib", + "tensorflow.contrib.all_reduce", + "tensorflow.contrib.all_reduce.python", + "tensorflow.contrib.layers", + "tensorflow.contrib.slim", + "tensorflow.contrib.rnn", + "tensorflow.core", + "tensorflow.core.util", + "tensorflow.python", + "tensorflow.python.client", + "tensorflow.python.util", + "ray.raylet", + "ray.plasma", + "ray.core", + "ray.core.generated", + "ray.core.generated.ClientTableData", + "ray.core.generated.GcsTableEntry", + "ray.core.generated.HeartbeatTableData", + "ray.core.generated.HeartbeatBatchTableData", + "ray.core.generated.DriverTableData", + "ray.core.generated.ErrorTableData", + "ray.core.generated.ProfileTableData", + "ray.core.generated.ObjectTableData", + "ray.core.generated.ray.protocol.Task", + "ray.core.generated.TablePrefix", + "ray.core.generated.TablePubsub", +] for mod_name in MOCK_MODULES: - sys.modules[mod_name] = mock.Mock() + sys.modules[mod_name] = mock.Mock() # ray.rllib.models.action_dist.py and # ray.rllib.models.lstm.py will use tf.VERSION sys.modules["tensorflow"].VERSION = "9.9.9" @@ -89,7 +86,7 @@ source_suffix = ['.rst', '.md'] source_parsers = { - '.md': CommonMarkParser, + '.md': CommonMarkParser, } # The encoding of source files. @@ -259,25 +256,24 @@ # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # Additional stuff for the LaTeX preamble. + #'preamble': '', -# Latex figure (float) alignment -#'figure_align': 'htbp', + # Latex figure (float) alignment + #'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). 
latex_documents = [ - (master_doc, 'Ray.tex', u'Ray Documentation', - u'The Ray Team', 'manual'), + (master_doc, 'Ray.tex', u'Ray Documentation', u'The Ray Team', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -300,29 +296,23 @@ # If false, no module index is generated. #latex_domain_indices = True - # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'ray', u'Ray Documentation', - [author], 1) -] +man_pages = [(master_doc, 'ray', u'Ray Documentation', [author], 1)] # If true, show URL addresses after external links. #man_show_urls = False - # -- Options for Texinfo output ------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Ray', u'Ray Documentation', - author, 'Ray', 'One line description of project.', - 'Miscellaneous'), + (master_doc, 'Ray', u'Ray Documentation', author, 'Ray', + 'One line description of project.', 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. diff --git a/doc/source/custom_metric.png b/doc/source/custom_metric.png new file mode 100644 index 0000000000000..3f448613711a3 Binary files /dev/null and b/doc/source/custom_metric.png differ diff --git a/doc/source/distributed_sgd.rst b/doc/source/distributed_sgd.rst new file mode 100644 index 0000000000000..5d1e480766258 --- /dev/null +++ b/doc/source/distributed_sgd.rst @@ -0,0 +1,56 @@ +Distributed SGD (Experimental) +============================== + +Ray includes an implementation of synchronous distributed stochastic gradient descent (SGD), which is competitive in performance with implementations in Horovod and Distributed TensorFlow. + +Ray SGD is built on top of the Ray task and actor abstractions to provide seamless integration into existing Ray applications. + +Interface +--------- + +To use Ray SGD, define a `model class `__ with ``loss`` and ``optimizer`` attributes: + +.. autoclass:: ray.experimental.sgd.Model + +Then, pass a model creator function to the ``ray.experimental.sgd.DistributedSGD`` class. To drive the distributed training, ``sgd.step()`` can be called repeatedly: + +.. code-block:: python + + model_creator = lambda worker_idx, device_idx: YourModelClass() + + sgd = DistributedSGD( + model_creator, + num_workers=2, + devices_per_worker=4, + gpu=True, + strategy="ps") + + for i in range(NUM_ITERS): + sgd.step() + +Under the hood, Ray SGD will create *replicas* of your model onto each hardware device (GPU) allocated to workers (controlled by ``num_workers``). Multiple devices can be managed by each worker process (controlled by ``devices_per_worker``). Each model instance will be in a separate TF variable scope. The ``DistributedSGD`` class coordinates the distributed computation and application of gradients to improve the model. + +There are two distributed SGD strategies available for use: + - ``strategy="simple"``: Gradients are averaged centrally on the driver before being applied to each model replica. This is a reference implementation for debugging purposes. + - ``strategy="ps"``: Gradients are computed and averaged within each node. Gradients are then averaged across nodes through a number of parameter server actors. 
To pipeline the computation of gradients and transmission across the network, we use a custom TensorFlow op that can read and write to the Ray object store directly. + +Note that when ``num_workers=1``, only local allreduce will be used and the choice of distributed strategy is irrelevant. + +The full documentation for ``DistributedSGD`` is as follows: + +.. autoclass:: ray.experimental.sgd.DistributedSGD + +Examples +-------- + +For examples of end-to-end usage, check out the `ImageNet synthetic data test `__ and also the simple `MNIST training example `__, which includes examples of how access the model weights and monitor accuracy as training progresses. + +Performance +----------- + +When using the new Ray backend (which will be enabled by default in Ray 0.6+), we `expect `__ performance competitive with other synchronous SGD implementations on 25Gbps Ethernet. + +.. figure:: sgd.png + :width: 756px + + Images per second reached when distributing the training of a ResNet-101 TensorFlow model (from the official TF benchmark). All experiments were run on p3.16xl instances connected by 25Gbps Ethernet, and workers allocated 4 GPUs per node as done in the Horovod benchmark. diff --git a/doc/source/example-a3c.rst b/doc/source/example-a3c.rst index 665d49a365512..47378fce9f915 100644 --- a/doc/source/example-a3c.rst +++ b/doc/source/example-a3c.rst @@ -9,11 +9,11 @@ View the `code for this example`_. .. _`A3C`: https://arxiv.org/abs/1602.01783 .. _`Universe Starter Agent`: https://github.com/openai/universe-starter-agent -.. _`code for this example`: https://github.com/ray-project/ray/tree/master/python/ray/rllib/a3c +.. _`code for this example`: https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents/a3c .. note:: - For an overview of Ray's reinforcement learning library, see `Ray RLlib `__. + For an overview of Ray's reinforcement learning library, see `RLlib `__. To run the application, first install **ray** and then some dependencies: @@ -29,7 +29,7 @@ You can run the code with .. code-block:: bash - python/ray/rllib/train.py --env=Pong-ram-v4 --run=A3C --config='{"num_workers": N}' + rllib train --env=Pong-ram-v4 --run=A3C --config='{"num_workers": N}' Reinforcement Learning ---------------------- diff --git a/doc/source/example-evolution-strategies.rst b/doc/source/example-evolution-strategies.rst index 16cdc3126d8f5..d048d261fff95 100644 --- a/doc/source/example-evolution-strategies.rst +++ b/doc/source/example-evolution-strategies.rst @@ -11,20 +11,20 @@ To run the application, first install some dependencies. You can view the `code for this example`_. -.. _`code for this example`: https://github.com/ray-project/ray/tree/master/python/ray/rllib/es +.. _`code for this example`: https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents/es The script can be run as follows. Note that the configuration is tuned to work on the ``Humanoid-v1`` gym environment. .. code-block:: bash - python/ray/rllib/train.py --env=Humanoid-v1 --run=ES + rllib train --env=Humanoid-v1 --run=ES To train a policy on a cluster (e.g., using 900 workers), run the following. .. code-block:: bash - python ray/python/ray/rllib/train.py \ + rllib train \ --env=Humanoid-v1 \ --run=ES \ --redis-address= \ diff --git a/doc/source/example-policy-gradient.rst b/doc/source/example-policy-gradient.rst index 806764560ba95..9b58575044c3b 100644 --- a/doc/source/example-policy-gradient.rst +++ b/doc/source/example-policy-gradient.rst @@ -6,7 +6,7 @@ View the `code for this example`_. .. 
note:: - For an overview of Ray's reinforcement learning library, see `Ray RLlib `__. + For an overview of Ray's reinforcement learning library, see `RLlib `__. To run this example, you will need to install `TensorFlow with GPU support`_ (at @@ -21,7 +21,7 @@ Then you can run the example as follows. .. code-block:: bash - python/ray/rllib/train.py --env=Pong-ram-v4 --run=PPO + rllib train --env=Pong-ram-v4 --run=PPO This will train an agent on the ``Pong-ram-v4`` Atari environment. You can also try passing in the ``Pong-v0`` environment or the ``CartPole-v0`` environment. @@ -39,4 +39,4 @@ Many of the TensorBoard metrics are also printed to the console, but you might find it easier to visualize and compare between runs using the TensorBoard UI. .. _`TensorFlow with GPU support`: https://www.tensorflow.org/install/ -.. _`code for this example`: https://github.com/ray-project/ray/tree/master/python/ray/rllib/ppo +.. _`code for this example`: https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents/ppo diff --git a/doc/source/fault-tolerance.rst b/doc/source/fault-tolerance.rst index a4692f904feea..6c388c9f8883e 100644 --- a/doc/source/fault-tolerance.rst +++ b/doc/source/fault-tolerance.rst @@ -6,16 +6,9 @@ This document describes the handling of failures in Ray. Machine and Process Failures ---------------------------- -Currently, each **local scheduler** and each **plasma manager** send heartbeats -to a **monitor** process. If the monitor does not receive any heartbeats from a -given process for some duration of time (about ten seconds), then it will mark -that process as dead. The monitor process will then clean up the associated -state in the Redis servers. If a manager is marked as dead, the object table -will be updated to remove all occurrences of that manager so that other managers -don't try to fetch objects from the dead manager. If a local scheduler is marked -as dead, all of the tasks that are marked as executing on that local scheduler -in the task table will be marked as lost and all actors associated with that -local scheduler will be recreated by other local schedulers. +Each **raylet** (the scheduler process) sends heartbeats to a **monitor** +process. If the monitor does not receive any heartbeats from a given raylet for +some period of time (about ten seconds), then it will mark that process as dead. Lost Objects ------------ @@ -23,19 +16,16 @@ Lost Objects If an object is needed but is lost or was never created, then the task that created the object will be re-executed to create the object. If necessary, tasks needed to create the input arguments to the task being re-executed will also be -re-executed. +re-executed. This is the standard *lineage-based fault tolerance* strategy used +by other systems like Spark. Actors ------ -When a local scheduler is marked as dead, all actors associated with that local -scheduler that were still alive will be recreated by other local schedulers. By -default, all of the actor methods will be re-executed in the same order that -they were initially executed. If actor checkpointing is enabled, then the actor -state will be loaded from the most recent checkpoint and the actor methods that -occurred after the checkpoint will be re-executed. Note that actor checkpointing -is currently an experimental feature. - +When an actor dies (either because the actor process crashed or because the node +that the actor was on died), by default any attempt to get an object from that +actor that cannot be created will raise an exception. 
Subsequent releases will +include an option for automatically restarting actors. Current Limitations ------------------- @@ -47,7 +37,7 @@ Process Failures ~~~~~~~~~~~~~~~~ 1. Ray does not recover from the failure of any of the following processes: - a Redis server, the global scheduler, the monitor process. + a Redis server and the monitor process. 2. If a driver fails, that driver will not be restarted and the job will not complete. @@ -58,9 +48,3 @@ Lost Objects evicted, and is later needed, Ray will not reconstruct this object. 2. If an object is constructed by an actor method, is then evicted, and is later needed, Ray will not reconstruct this object. - -Actor Reconstruction -~~~~~~~~~~~~~~~~~~~~ - -1. Actor reconstruction follows the order of initial execution, but new tasks - may get interleaved with the re-executed tasks. diff --git a/doc/source/images/ray_logo.png b/doc/source/images/ray_logo.png new file mode 100644 index 0000000000000..05840a7ff453e Binary files /dev/null and b/doc/source/images/ray_logo.png differ diff --git a/doc/source/impala.png b/doc/source/impala.png index a7d12e4b5a0f9..0d42fe6e07dc9 100644 Binary files a/doc/source/impala.png and b/doc/source/impala.png differ diff --git a/doc/source/index.rst b/doc/source/index.rst index b71987108be05..68a33676c80d8 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -42,6 +42,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin - `Tune`_: Scalable Hyperparameter Search - `RLlib`_: Scalable Reinforcement Learning +- `Distributed Training `__ .. _`Tune`: tune.html .. _`RLlib`: rllib.html @@ -64,6 +65,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin actors.rst using-ray-with-gpus.rst webui.rst + async_api.rst .. toctree:: :maxdepth: 1 @@ -74,10 +76,11 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin tune-schedulers.rst tune-searchalg.rst tune-package-ref.rst + tune-examples.rst .. toctree:: :maxdepth: 1 - :caption: Ray RLlib + :caption: RLlib rllib.rst rllib-training.rst @@ -89,8 +92,9 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin .. toctree:: :maxdepth: 1 - :caption: Pandas on Ray + :caption: Other Libraries + distributed_sgd.rst pandas_on_ray.rst .. toctree:: @@ -118,6 +122,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin plasma-object-store.rst resources.rst redis-memory-management.rst + tempfile.rst .. toctree:: :maxdepth: 1 @@ -134,6 +139,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin troubleshooting.rst user-profiling.rst + security.rst development.rst profiling.rst contact.rst diff --git a/doc/source/installation.rst b/doc/source/installation.rst index 4c4bc3f165ef7..68bd37ae96f5d 100644 --- a/doc/source/installation.rst +++ b/doc/source/installation.rst @@ -3,17 +3,23 @@ Installing Ray Ray should work with Python 2 and Python 3. We have tested Ray on Ubuntu 14.04, Ubuntu 16.04, OS X 10.11 and 10.12. -You can install Ray as follows. +Latest stable version +--------------------- + +You can install the latest stable version of Ray as follows. .. code-block:: bash - pip install ray + pip install -U ray # also recommended: ray[debug] + +Trying snapshots from master +---------------------------- + +Here are links to the latest wheels (which are built off of master). To install these wheels, run the following command: -Trying the latest version of Ray --------------------------------- +.. 
danger:: -Here are links to the latest wheels (which are built off of master). These versions will have newer -features but may be subject to more bugs. To install these wheels, run the following command: + These versions will have newer features but are subject to more bugs. If you encounter crashes or other instabilities, please revert to the latest stable version. .. code-block:: bash @@ -23,6 +29,7 @@ features but may be subject to more bugs. To install these wheels, run the follo =================== =================== Linux MacOS =================== =================== +`Linux Python 3.7`_ `MacOS Python 3.7`_ `Linux Python 3.6`_ `MacOS Python 3.6`_ `Linux Python 3.5`_ `MacOS Python 3.5`_ `Linux Python 3.4`_ `MacOS Python 3.4`_ @@ -30,14 +37,16 @@ features but may be subject to more bugs. To install these wheels, run the follo =================== =================== -.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp36-cp36m-manylinux1_x86_64.whl -.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp35-cp35m-manylinux1_x86_64.whl -.. _`Linux Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp34-cp34m-manylinux1_x86_64.whl -.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp27-cp27mu-manylinux1_x86_64.whl -.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp36-cp36m-macosx_10_6_intel.whl -.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp35-cp35m-macosx_10_6_intel.whl -.. _`MacOS Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp34-cp34m-macosx_10_6_intel.whl -.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp27-cp27m-macosx_10_6_intel.whl +.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp37-cp37m-manylinux1_x86_64.whl +.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp36-cp36m-manylinux1_x86_64.whl +.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp35-cp35m-manylinux1_x86_64.whl +.. _`Linux Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp34-cp34m-manylinux1_x86_64.whl +.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp27-cp27mu-manylinux1_x86_64.whl +.. _`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp37-cp37m-macosx_10_6_intel.whl +.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp36-cp36m-macosx_10_6_intel.whl +.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp35-cp35m-macosx_10_6_intel.whl +.. _`MacOS Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp34-cp34m-macosx_10_6_intel.whl +.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp27-cp27m-macosx_10_6_intel.whl Building Ray from source @@ -67,7 +76,7 @@ For Ubuntu, run the following commands: # If you are on Ubuntu 14.04, you need the following. 
pip install cmake - pip install cython + pip install cython==0.27.3 For MacOS, run the following commands: @@ -76,7 +85,7 @@ For MacOS, run the following commands: brew update brew install cmake pkg-config automake autoconf libtool openssl bison wget - pip install cython + pip install cython==0.27.3 If you are using Anaconda, you may also need to run the following. diff --git a/doc/source/internals-overview.rst b/doc/source/internals-overview.rst index 69ac1895a55c7..a2516de1d10ce 100644 --- a/doc/source/internals-overview.rst +++ b/doc/source/internals-overview.rst @@ -15,8 +15,8 @@ Running Ray standalone Ray can be used standalone by calling ``ray.init()`` within a script. When the call to ``ray.init()`` happens, all of the relevant processes are started. -These include a local scheduler, a global scheduler, an object store and -manager, a Redis server, and a number of worker processes. +These include a local scheduler, an object store and manager, a Redis server, +and a number of worker processes. When the script exits, these processes will be killed. @@ -112,7 +112,7 @@ When a driver or worker invokes a remote function, a number of things happen. - The task object is then sent to the local scheduler on the same node as the driver or worker. - The local scheduler makes a decision to either schedule the task locally or to - pass the task on to a global scheduler. + pass the task on to another local scheduler. - If all of the task's object dependencies are present in the local object store and there are enough CPU and GPU resources available to execute the diff --git a/doc/source/profiling.rst b/doc/source/profiling.rst index 59d12d635cdeb..55ed8de6fae2b 100644 --- a/doc/source/profiling.rst +++ b/doc/source/profiling.rst @@ -14,54 +14,20 @@ symbolize on Mac OS have failed. sudo apt-get install google-perftools libgoogle-perftools-dev -Changes to compilation and linking ----------------------------------- - -Let's say we want to profile the ``plasma_manager``. Change the link -instruction in ``src/plasma/CMakeLists.txt`` from - -.. code-block:: cmake - - target_link_libraries(plasma_manager common ${PLASMA_STATIC_LIB} ray_static ${ARROW_STATIC_LIB} -lpthread) - -to additionally include ``-lprofiler``: - -.. code-block:: cmake - - target_link_libraries(plasma_manager common ${PLASMA_STATIC_LIB} ray_static ${ARROW_STATIC_LIB} -lpthread -lprofiler) - -Additionally, add ``-g -ggdb`` to ``CMAKE_C_FLAGS`` and ``CMAKE_CXX_FLAGS`` to -enable the debug symbols. (Keeping ``-O3`` seems okay.) - -Recompile. - Launching the to-profile binary ------------------------------- -In various places, instead of launching the target binary via -``plasma_manager ``, it must be launched with +If you want to launch Ray in profiling mode, define the following variables: .. code-block:: bash - LD_PRELOAD=/usr/lib/libprofiler.so CPUPROFILE=/tmp/pprof.out plasma_manager - -In practice, this means modifying ``python/ray/plasma/plasma.py`` so that the -manager is launched with a command that passes a ``modified_env`` into -``Popen``. - -.. 
code-block:: python - - modified_env = os.environ.copy() - modified_env["LD_PRELOAD"] = "/usr/lib/libprofiler.so" - modified_env["CPUPROFILE"] = "/tmp/pprof.out" + export RAYLET_PERFTOOLS_PATH=/usr/lib/x86_64-linux-gnu/libprofiler.so + export RAYLET_PERFTOOLS_LOGFILE=/tmp/pprof.out - process = subprocess.Popen(command, - stdout=stdout_file, - stderr=stderr_file, - env=modified_env) The file ``/tmp/pprof.out`` will be empty until you let the binary run the -target workload for a while and then ``kill`` it. +target workload for a while and then ``kill`` it via ``ray stop`` or by +letting the driver exit. Visualizing the CPU profile --------------------------- @@ -72,14 +38,14 @@ zoomable ``.svg`` image displaying the call graph annotated with hot paths. .. code-block:: bash # Use the appropriate path. - PLASMA_MANAGER=ray/python/ray/core/src/plasma/plasma_manager + RAYLET=ray/python/ray/core/src/ray/raylet/raylet - google-pprof -svg $PLASMA_MANAGER /tmp/pprof.out > /tmp/pprof.svg + google-pprof -svg $RAYLET /tmp/pprof.out > /tmp/pprof.svg # Then open the .svg file with Chrome. # If you realize the call graph is too large, use -focus= to zoom # into subtrees. - google-pprof -focus=epoll_wait -svg $PLASMA_MANAGER /tmp/pprof.out > /tmp/pprof.svg + google-pprof -focus=epoll_wait -svg $RAYLET /tmp/pprof.out > /tmp/pprof.svg Here's a snapshot of an example svg output, taken from the official documentation: diff --git a/doc/source/redis-memory-management.rst b/doc/source/redis-memory-management.rst index 64d2035ed0f31..5e6edcc02f6c4 100644 --- a/doc/source/redis-memory-management.rst +++ b/doc/source/redis-memory-management.rst @@ -1,4 +1,4 @@ -Redis Memory Management (EXPERIMENTAL) +Redis Memory Management (Experimental) ====================================== Ray stores metadata associated with tasks and objects in one or more Redis @@ -7,92 +7,9 @@ servers, as described in `An Overview of the Internals task/object generation rate could risk high memory pressure, potentially leading to out-of-memory (OOM) errors. -Here, we describe an experimental feature that transparently flushes metadata -entries out of Redis memory. +In Ray `0.6.1+` Redis shards can be configured to LRU evict task and object +metadata by setting ``redis_max_memory`` when starting Ray. This supercedes the +previously documented flushing functionality. -Requirements ------------- - -As of early July 2018, the automatic memory management feature requires building -Ray from source. We are planning on eliminating this step in the near future by -releasing official wheels. - -Building Ray -~~~~~~~~~~~~ - -First, follow `instructions to build Ray from source -`__ to install prerequisites. After -the prerequisites are installed, instead of doing the regular ``pip install`` as -referenced in that document, pass an additional special flag, -``RAY_USE_NEW_GCS=on``: - -.. code-block:: bash - - git clone https://github.com/ray-project/ray.git - cd ray/python - RAY_USE_NEW_GCS=on pip install -e . --verbose # Add --user if you see a permission denied error. - -Running Ray applications -~~~~~~~~~~~~~~~~~~~~~~~~ - -At run time the environment variables ``RAY_USE_NEW_GCS=on`` and -``RAY_USE_XRAY=1`` are required. - -.. code-block:: bash - - export RAY_USE_NEW_GCS=on - export RAY_USE_XRAY=1 - python my_ray_script.py # Or launch python/ipython. - -Activate memory flushing ------------------------- - -After building Ray using the method above, simply add these two lines after -``ray.init()`` to activate automatic memory flushing: - -.. 
code-block:: python - - ray.init(...) - - policy = ray.experimental.SimpleGcsFlushPolicy() - ray.experimental.set_flushing_policy(policy) - - # My awesome Ray application logic follows. - -Paramaters of the flushing policy ---------------------------------- - -There are three `user-configurable parameters -`_ -of the ``SimpleGcsFlushPolicy``: - -* ``flush_when_at_least_bytes``: Wait until this many bytes of memory usage - accumulated in the redis server before flushing kicks in. -* ``flush_period_secs``: Issue a flush to the Redis server every this many - seconds. -* ``flush_num_entries_each_time``: A hint to the system on the number of entries - to flush on each request. - -The default values should serve to be non-invasive for lightweight Ray -applications. ``flush_when_at_least_bytes`` is set to ``(1<<31)`` or 2GB, -``flush_period_secs`` to 10, and ``flush_num_entries_each_time`` to 10000: - -.. code-block:: python - - # Default parameters. - ray.experimental.SimpleGcsFlushPolicy( - flush_when_at_least_bytes=(1 << 31), - flush_period_secs=10, - flush_num_entries_each_time=10000) - -In particular, these default values imply that - -1. the Redis server would accumulate memory usage up to 2GB without any entries -being flushed, then the flushing would kick in; and - -2. generally, "older" metadata entries would be flushed first, and the Redis -server would always keep the most recent window of metadata of 2GB in size. - -**For advanced users.** Advanced users can tune the above parameters to their -applications' needs; note that the desired flush rate is equal to (flush -period) * (num entries each flush). +Note that profiling is disabled when ``redis_max_memory`` is set. This is because +profiling data cannot be LRU evicted. diff --git a/doc/source/resources.rst b/doc/source/resources.rst index e0dc9d742ec28..4be2f61afbe4b 100644 --- a/doc/source/resources.rst +++ b/doc/source/resources.rst @@ -1,5 +1,5 @@ -Resource (CPUs, GPUs) -===================== +Resources (CPUs, GPUs) +====================== This document describes how resources are managed in Ray. Each node in a Ray cluster knows its own resource capacities, and each task specifies its resource @@ -39,7 +39,8 @@ Specifying a task's CPU and GPU requirements ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ To specify a task's CPU and GPU requirements, pass the ``num_cpus`` and -``num_gpus`` arguments into the remote decorator. +``num_gpus`` arguments into the remote decorator. Note that Ray supports +**fractional** resource requirements. .. code-block:: python @@ -47,7 +48,11 @@ To specify a task's CPU and GPU requirements, pass the ``num_cpus`` and def f(): return 1 -When ``f`` tasks will be scheduled on machines that have at least 4 CPUs and 2 + @ray.remote(num_gpus=0.5) + def h(): + return 1 + +The ``f`` tasks will be scheduled on machines that have at least 4 CPUs and 2 GPUs, and when one of the ``f`` tasks executes, 4 CPUs and 2 GPUs will be reserved for that task. The IDs of the GPUs that are reserved for the task can be accessed with ``ray.get_gpu_ids()``. Ray will automatically set the @@ -108,3 +113,9 @@ decorator. @ray.remote(resources={'Resource2': 1}) def f(): return 1 + +Fractional Resources +-------------------- + +Task and actor resource requirements can be fractional. This is particularly +useful if you want multiple tasks or actors to share a single GPU. 
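To make the fractional-resource behavior concrete, here is a minimal sketch (not part of the original docs; the ``0.25`` value and the names are illustrative) of several tasks and an actor sharing one GPU:

.. code-block:: python

    import ray

    ray.init(num_gpus=1)

    # Four of these tasks can run concurrently, since each one reserves
    # only a quarter of the single GPU.
    @ray.remote(num_gpus=0.25)
    def gpu_task():
        return ray.get_gpu_ids()

    # Actors can reserve fractional GPUs in the same way.
    @ray.remote(num_gpus=0.25)
    class GPUActor(object):
        def gpu_ids(self):
            return ray.get_gpu_ids()

    print(ray.get([gpu_task.remote() for _ in range(4)]))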
diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index d764fc7ad8ea3..1d0501215745c 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -38,14 +38,21 @@ SpaceInvaders 646 ~300 Ape-X using 32 workers in RLlib vs vanilla DQN (orange) and A3C (blue) on PongNoFrameskip-v4. +**Ape-X specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/dqn/apex.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Importance Weighted Actor-Learner Architecture (IMPALA) ------------------------------------------------------- `[paper] `__ `[implementation] `__ -In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code `__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model `__. +In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code `__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model `__. Multiple learner GPUs and experience replay are also supported. -Tuned examples: `PongNoFrameskip-v4 `__, `vectorized configuration `__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 `__ +Tuned examples: `PongNoFrameskip-v4 `__, `vectorized configuration `__, `multi-gpu configuration `__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 `__ **Atari results @10M steps**: `more details `__ @@ -71,7 +78,15 @@ SpaceInvaders 843 ~300 .. figure:: impala.png - IMPALA solves Atari several times faster than A2C / A3C, with similar sample efficiency. Here IMPALA scales from 16 to 128 workers to solve PongNoFrameskip-v4 in ~8 minutes. + Multi-GPU IMPALA scales up to solve PongNoFrameskip-v4 in ~3 minutes using a pair of V100 GPUs and 128 CPU workers. + The maximum training throughput reached is ~30k transitions per second (~120k environment frames per second). + +**IMPALA-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/impala/impala.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ Gradient-based ~~~~~~~~~~~~~~ @@ -97,17 +112,31 @@ Qbert 3620 ~1000 SpaceInvaders 692 ~600 ============= ======================== ============================== -Deep Deterministic Policy Gradients (DDPG) ------------------------------------------- +**A3C-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/a3c/a3c.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + +Deep Deterministic Policy Gradients (DDPG, TD3) +----------------------------------------------- `[paper] `__ `[implementation] `__ -DDPG is implemented similarly to DQN (below). The algorithm can be scaled by increasing the number of workers, switching to AsyncGradientsOptimizer, or using Ape-X. +DDPG is implemented similarly to DQN (below). The algorithm can be scaled by increasing the number of workers, switching to AsyncGradientsOptimizer, or using Ape-X. The improvements from `TD3 `__ are available though not enabled by default. 
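As a quick, illustrative invocation (not from the original docs; the worker count is arbitrary), DDPG can be launched on a continuous-control environment with the ``rllib train`` command described in the training documentation:

.. code-block:: bash

    rllib train --run DDPG --env Pendulum-v0 --config '{"num_workers": 4}'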
+ +Tuned examples: `Pendulum-v0 `__, `TD3 configuration `__, `MountainCarContinuous-v0 `__, `HalfCheetah-v2 `__ -Tuned examples: `Pendulum-v0 `__, `MountainCarContinuous-v0 `__, `HalfCheetah-v2 `__ +**DDPG-specific configs** (see also `common configs `__): -Deep Q Networks (DQN, Rainbow) ------------------------------- +.. literalinclude:: ../../python/ray/rllib/agents/ddpg/ddpg.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + +Deep Q Networks (DQN, Rainbow, Parametric DQN) +---------------------------------------------- `[paper] `__ `[implementation] `__ -RLlib DQN is implemented using the SyncReplayOptimizer. The algorithm can be scaled by increasing the number of workers, using the AsyncGradientsOptimizer for async DQN, or using Ape-X. Memory usage is reduced by compressing samples in the replay buffer with LZ4. All of the DQN improvements evaluated in `Rainbow `__ are available, though not all are enabled by default. +RLlib DQN is implemented using the SyncReplayOptimizer. The algorithm can be scaled by increasing the number of workers, using the AsyncGradientsOptimizer for async DQN, or using Ape-X. Memory usage is reduced by compressing samples in the replay buffer with LZ4. All of the DQN improvements evaluated in `Rainbow `__ are available, though not all are enabled by default. See also how to use `parametric-actions in DQN `__. Tuned examples: `PongDeterministic-v4 `__, `Rainbow configuration `__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 `__, `with Dueling and Double-Q `__, `with Distributional DQN `__. @@ -125,12 +154,26 @@ Qbert 3921 7968 15780 SpaceInvaders 650 1001 1025 ~500 ============= ======================== ============================= ============================== =============================== +**DQN-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/dqn/dqn.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Policy Gradients ---------------- `[paper] `__ `[implementation] `__ We include a vanilla policy gradients implementation as an example algorithm. This is usually outperformed by PPO. Tuned examples: `CartPole-v0 `__ +**PG-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/pg/pg.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Proximal Policy Optimization (PPO) ---------------------------------- `[paper] `__ `[implementation] `__ @@ -158,6 +201,13 @@ SpaceInvaders 671 944 ~800 RLlib's multi-GPU PPO scales to multiple GPUs and hundreds of CPUs on solving the Humanoid-v1 task. Here we compare against a reference MPI-based implementation. +**PPO-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/ppo/ppo.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Derivative-free ~~~~~~~~~~~~~~~ @@ -168,6 +218,13 @@ ARS is a random search method for training linear policies for continuous contro Tuned examples: `CartPole-v0 `__, `Swimmer-v2 `__ +**ARS-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/ars/ars.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Evolution Strategies -------------------- `[paper] `__ `[implementation] `__ @@ -181,3 +238,10 @@ Tuned examples: `Humanoid-v1 `__): + +.. 
literalinclude:: ../../python/ray/rllib/agents/es/es.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index f752279cb58d3..68c160c912b05 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -17,7 +17,7 @@ Policy Evaluation Given an environment and policy graph, policy evaluation produces `batches `__ of experiences. This is your classic "environment interaction loop". Efficient policy evaluation can be burdensome to get right, especially when leveraging vectorization, RNNs, or when operating in a multi-agent environment. RLlib provides a `PolicyEvaluator `__ class that manages all of this, and this class is used in most RLlib algorithms. -You can also use policy evaluation standalone to produce batches of experiences. This can be done by calling ``ev.sample()`` on an evaluator instance, or ``ev.sample.remote()`` in parallel on evaluator instances created as Ray actors (see ``PolicyEvalutor.as_remote()``). +You can also use policy evaluation standalone to produce batches of experiences. This can be done by calling ``ev.sample()`` on an evaluator instance, or ``ev.sample.remote()`` in parallel on evaluator instances created as Ray actors (see ``PolicyEvaluator.as_remote()``). Policy Optimization ------------------- diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index 6de076785707f..4f8a4c66ae4c3 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -5,27 +5,58 @@ RLlib works with several different types of environments, including `OpenAI Gym .. image:: rllib-envs.svg -In the high-level agent APIs, environments are identified with string names. By default, the string will be interpreted as a gym `environment name `__, however you can also register custom environments by name: +**Compatibility matrix**: + +============= ======================= ================== =========== ================== +Algorithm Discrete Actions Continuous Actions Multi-Agent Recurrent Policies +============= ======================= ================== =========== ================== +A2C, A3C **Yes** `+parametric`_ **Yes** **Yes** **Yes** +PPO **Yes** `+parametric`_ **Yes** **Yes** **Yes** +PG **Yes** `+parametric`_ **Yes** **Yes** **Yes** +IMPALA **Yes** `+parametric`_ No **Yes** **Yes** +DQN, Rainbow **Yes** `+parametric`_ No **Yes** No +DDPG, TD3 No **Yes** **Yes** No +APEX-DQN **Yes** `+parametric`_ No **Yes** No +APEX-DDPG No **Yes** **Yes** No +ES **Yes** **Yes** No No +ARS **Yes** **Yes** No No +============= ======================= ================== =========== ================== + +.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces + +You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name `__. Custom env classes must take a single ``env_config`` parameter in their constructor: .. code-block:: python import ray - from ray.tune.registry import register_env from ray.rllib.agents import ppo - def env_creator(env_config): - import gym - return gym.make("CartPole-v0") # or return your own custom env + class MyEnv(gym.Env): + def __init__(self, env_config): + self.action_space = ... + self.observation_space = ... + ... 
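        # Editor's sketch (not part of the original patch): a custom env must
        # also implement the usual Gym methods; the shapes and values below
        # are placeholders.
        def reset(self):
            # Return an initial observation.
            return self.observation_space.sample()

        def step(self, action):
            # Return (observation, reward, done, info) per the Gym API.
            return self.observation_space.sample(), 0.0, True, {}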
- register_env("my_env", env_creator) ray.init() - trainer = ppo.PPOAgent(env="my_env", config={ - "env_config": {}, # config to pass to env creator + trainer = ppo.PPOAgent(env=MyEnv, config={ + "env_config": {}, # config to pass to env class }) while True: print(trainer.train()) +You can also register a custom env creator function with a string name. This function must take a single ``env_config`` parameter and return an env instance: + +.. code-block:: python + + from ray.tune.registry import register_env + + def env_creator(env_config): + return MyEnv(...) # return an env instance + + register_env("my_env", env_creator) + trainer = ppo.PPOAgent(env="my_env") + Configuring Environments ------------------------ @@ -50,14 +81,14 @@ In the above example, note that the ``env_creator`` function takes in an ``env_c OpenAI Gym ---------- -RLlib uses Gym as its environment interface for single-agent training. For more information on how to implement a custom Gym environment, see the `gym.Env class definition `__. You may also find the `SimpleCorridor `__ and `Carla simulator `__ example env implementations useful as a reference. +RLlib uses Gym as its environment interface for single-agent training. For more information on how to implement a custom Gym environment, see the `gym.Env class definition `__. You may also find the `SimpleCorridor `__ and `Carla simulator `__ example env implementations useful as a reference. Performance ~~~~~~~~~~~ There are two ways to scale experience collection with Gym environments: - 1. **Vectorization within a single process:** Though many envs can very achieve high frame rates per core, their throughput is limited in practice by policy evaluation between steps. For example, even small TensorFlow models incur a couple milliseconds of latency to evaluate. This can be worked around by creating multiple envs per process and batching policy evaluations across these envs. + 1. **Vectorization within a single process:** Though many envs can achieve high frame rates per core, their throughput is limited in practice by policy evaluation between steps. For example, even small TensorFlow models incur a couple milliseconds of latency to evaluate. This can be worked around by creating multiple envs per process and batching policy evaluations across these envs. You can configure ``{"num_envs_per_worker": M}`` to have RLlib create ``M`` concurrent environments per worker. RLlib auto-vectorizes Gym environments via `VectorEnv.wrap() `__. @@ -76,6 +107,10 @@ RLlib will auto-vectorize Gym envs for batch evaluation if the ``num_envs_per_wo Multi-Agent ----------- +.. note:: + + Learn more about multi-agent reinforcement learning in RLlib by reading the `blog post `__. + A multi-agent environment is one which has multiple acting entities per step, e.g., in a traffic simulation, there may be multiple "car" and "traffic light" agents in the environment. The model for multi-agent in RLlib as follows: (1) as a user you define the number of policies available up front, and (2) a function that maps agent ids to policy ids. This is summarized by the below figure: .. image:: multi-agent.svg @@ -132,25 +167,93 @@ If all the agents will be using the same algorithm class to train, then you can RLlib will create three distinct policies and route agent decisions to its bound policy. When an agent first appears in the env, ``policy_mapping_fn`` will be called to determine which policy it is bound to. 
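As a sketch of what such a mapping could look like (the agent id prefixes and policy names below are hypothetical, following the traffic example above):

.. code-block:: python

    def policy_mapping_fn(agent_id):
        # Route traffic-light agents to one policy and car agents to another.
        if agent_id.startswith("traffic_light"):
            return "traffic_light_policy"
        return "car_policy"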
RLlib reports separate training statistics for each policy in the return from ``train()``, along with the combined reward. -Here is a simple `example training script `__ in which you can vary the number of agents and policies in the environment. For how to use multiple training methods at once (here DQN and PPO), see the `two-trainer example `__. +Here is a simple `example training script `__ in which you can vary the number of agents and policies in the environment. For how to use multiple training methods at once (here DQN and PPO), see the `two-trainer example `__. Metrics are reported for each policy separately, for example: + +.. code-block:: bash + :emphasize-lines: 6,14,22 + + Result for PPO_multi_cartpole_0: + episode_len_mean: 34.025862068965516 + episode_reward_max: 159.0 + episode_reward_mean: 86.06896551724138 + info: + policy_0: + cur_lr: 4.999999873689376e-05 + entropy: 0.6833480000495911 + kl: 0.010264254175126553 + policy_loss: -11.95590591430664 + total_loss: 197.7039794921875 + vf_explained_var: 0.0010995268821716309 + vf_loss: 209.6578826904297 + policy_1: + cur_lr: 4.999999873689376e-05 + entropy: 0.6827034950256348 + kl: 0.01119876280426979 + policy_loss: -8.787769317626953 + total_loss: 88.26161193847656 + vf_explained_var: 0.0005457401275634766 + vf_loss: 97.0471420288086 + policy_reward_mean: + policy_0: 21.194444444444443 + policy_1: 21.798387096774192 To scale to hundreds of agents, MultiAgentEnv batches policy evaluations across multiple agents internally. It can also be auto-vectorized by setting ``num_envs_per_worker > 1``. -Agent-Driven ------------- +Variable-Sharing Between Policies +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -In many situations, it does not make sense for an environment to be "stepped" by RLlib. For example, if a policy is to be used in a web serving system, then it is more natural for an agent to query a service that serves policy decisions, and for that service to learn from experience over time. +RLlib will create each policy's model in a separate ``tf.variable_scope``. However, variables can still be shared between policies by explicitly entering a globally shared variable scope with ``tf.VariableScope(reuse=tf.AUTO_REUSE)``: -RLlib provides the `ServingEnv `__ class for this purpose. Unlike other envs, ServingEnv has its own thread of control. At any point, agents on that thread can query the current policy for decisions via ``self.get_action()`` and reports rewards via ``self.log_returns()``. This can be done for multiple concurrent episodes as well. +.. code-block:: python + + with tf.variable_scope( + tf.VariableScope(tf.AUTO_REUSE, "name_of_global_shared_scope"), + reuse=tf.AUTO_REUSE, + auxiliary_name_scope=False): + + +There is a full example of this in the `example training script `__. + +Implementing a Centralized Critic +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Implementing a centralized critic that takes as input the observations and actions of other concurrent agents requires the definition of custom policy graphs. It can be done as follows: + +1. Querying the critic: this can be done in the ``postprocess_trajectory`` method of a custom policy graph, which has full access to the policies and observations of concurrent agents via the ``other_agent_batches`` and ``episode`` arguments. The batch of critic predictions can then be added to the postprocessed trajectory. Here's an example: + +.. code-block:: python -For example, ServingEnv can be used to implement a simple REST policy `server `__ that learns over time using RLlib. 
In this example RLlib runs with ``num_workers=0`` to avoid port allocation issues, but in principle this could be scaled by increasing ``num_workers``. + def postprocess_trajectory(self, sample_batch, other_agent_batches, episode): + agents = ["agent_1", "agent_2", "agent_3"] # simple example of 3 agents + global_obs_batch = np.stack( + [other_agent_batches[agent_id][1]["obs"] for agent_id in agents], + axis=1) + # add the global obs and global critic value + sample_batch["global_obs"] = global_obs_batch + sample_batch["central_vf"] = self.sess.run( + self.critic_network, feed_dict={"obs": global_obs_batch}) + return sample_batch -Offline Data -~~~~~~~~~~~~ +2. Updating the critic: the centralized critic loss can be added to the loss of the custom policy graph, the same as with any other value function. For an example of defining loss inputs, see the `PGPolicyGraph example `__. -ServingEnv also provides a ``self.log_action()`` call to support off-policy actions. This allows the client to make independent decisions, e.g., to compare two different policies, and for RLlib to still learn from those off-policy actions. Note that this requires the algorithm used to support learning from off-policy decisions (e.g., DQN). +Interfacing with External Agents +-------------------------------- + +In many situations, it does not make sense for an environment to be "stepped" by RLlib. For example, if a policy is to be used in a web serving system, then it is more natural for an agent to query a service that serves policy decisions, and for that service to learn from experience over time. This case also naturally arises with **external simulators** that run independently outside the control of RLlib, but may still want to leverage RLlib for training. + +RLlib provides the `ExternalEnv `__ class for this purpose. Unlike other envs, ExternalEnv has its own thread of control. At any point, agents on that thread can query the current policy for decisions via ``self.get_action()`` and reports rewards via ``self.log_returns()``. This can be done for multiple concurrent episodes as well. + +ExternalEnv can be used to implement a simple REST policy `server `__ that learns over time using RLlib. In this example RLlib runs with ``num_workers=0`` to avoid port allocation issues, but in principle this could be scaled by increasing ``num_workers``. + +Logging off-policy actions +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +ExternalEnv also provides a ``self.log_action()`` call to support off-policy actions. This allows the client to make independent decisions, e.g., to compare two different policies, and for RLlib to still learn from those off-policy actions. Note that this requires the algorithm used to support learning from off-policy decisions (e.g., DQN). + +Data ingest +~~~~~~~~~~~ -The ``log_action`` API of ServingEnv can be used to ingest data from offline logs. The pattern would be as follows: First, some policy is followed to produce experience data which is stored in some offline storage system. Then, RLlib creates a number of workers that use a ServingEnv to read the logs in parallel and ingest the experiences. After a round of training completes, the new policy can be deployed to collect more experiences. +The ``log_action`` API of ExternalEnv can be used to ingest data from offline logs. The pattern would be as follows: First, some policy is followed to produce experience data which is stored in some offline storage system. 
Then, RLlib creates a number of workers that use a ExternalEnv to read the logs in parallel and ingest the experiences. After a round of training completes, the new policy can be deployed to collect more experiences. Note that envs can read from different partitions of the logs based on the ``worker_index`` attribute of the `env context `__ passed into the environment constructor. diff --git a/doc/source/rllib-envs.svg b/doc/source/rllib-envs.svg index 37d6d66e6e1e5..2cc45dbf96fa7 100644 --- a/doc/source/rllib-envs.svg +++ b/doc/source/rllib-envs.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index a234ba0022420..9e7070b66c489 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -13,15 +13,24 @@ Built-in Models and Preprocessors RLlib picks default models based on a simple heuristic: a `vision network `__ for image observations, and a `fully connected network `__ for everything else. These models can be configured via the ``model`` config key, documented in the model `catalog `__. Note that you'll probably have to configure ``conv_filters`` if your environment observations have custom sizes, e.g., ``"model": {"dim": 42, "conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2], [512, [11, 11], 1]]}`` for 42x42 observations. -In addition, if you set ``"model": {"use_lstm": true}``, then the model output will be further processed by a `LSTM cell `__. More generally, RLlib supports the use of recurrent models for its algorithms (A3C, PG out of the box), and RNN support is built into its policy evaluation utilities. +In addition, if you set ``"model": {"use_lstm": true}``, then the model output will be further processed by a `LSTM cell `__. More generally, RLlib supports the use of recurrent models for its policy gradient algorithms (A3C, PPO, PG, IMPALA), and RNN support is built into its policy evaluation utilities. -For preprocessors, RLlib tries to pick one of its built-in preprocessor based on the environment's observation space. Discrete observations are one-hot encoded, Atari observations downscaled, and Tuple observations flattened (there isn't native tuple support yet, but you can reshape the flattened observation in a custom model). Note that for Atari, RLlib defaults to using the `DeepMind preprocessors `__, which are also used by the OpenAI baselines library. +For preprocessors, RLlib tries to pick one of its built-in preprocessor based on the environment's observation space. Discrete observations are one-hot encoded, Atari observations downscaled, and Tuple and Dict observations flattened (these are unflattened and accessible via the ``input_dict`` parameter in custom models). Note that for Atari, RLlib defaults to using the `DeepMind preprocessors `__, which are also used by the OpenAI baselines library. +Built-in Model Parameters +~~~~~~~~~~~~~~~~~~~~~~~~~ + +The following is a list of the built-in model hyperparameters: + +.. literalinclude:: ../../python/ray/rllib/models/catalog.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ Custom Models ------------- -Custom models should subclass the common RLlib `model class `__ and override the ``_build_layers`` method. This method takes in a tensor input (observation), and returns a feature layer and float vector of the specified output size. 
The model can then be registered and used in place of a built-in model: +Custom models should subclass the common RLlib `model class `__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``, ``is_training``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. A self-supervised loss can be defined via the ``loss`` method. The model can then be registered and used in place of a built-in model: .. code-block:: python @@ -30,12 +39,66 @@ Custom models should subclass the common RLlib `model class >> print(input_dict) + {'prev_actions': , + 'prev_rewards': , + 'is_training': , + 'obs': OrderedDict([ + ('sensors', OrderedDict([ + ('front_cam', [ + , + ]), + ('position', ), + ('velocity', )]))])} + """ + + layer1 = slim.fully_connected(input_dict["obs"], 64, ...) + layer2 = slim.fully_connected(layer1, 64, ...) ... return layerN, layerN_minus_1 + def value_function(self): + """Builds the value function output. + + This method can be overridden to customize the implementation of the + value function (e.g., not sharing hidden layers). + + Returns: + Tensor of size [BATCH_SIZE] for the value function. + """ + return tf.reshape( + linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1]) + + def loss(self): + """Builds any built-in (self-supervised) loss for the model. + + For example, this can be used to incorporate auto-encoder style losses. + Note that this loss has to be included in the policy graph loss to have + an effect (done for built-in algorithms). + + Returns: + Scalar tensor for the self-supervised loss. + """ + return tf.constant(0.0) + ModelCatalog.register_custom_model("my_model", MyModelClass) ray.init() @@ -46,12 +109,53 @@ Custom models should subclass the common RLlib `model class `__ and associated `training scripts `__. The ``CarlaModel`` class defined there operates over a composite (Tuple) observation space including both images and scalar measurements. +For a full example of a custom model in code, see the `Carla RLlib model `__ and associated `training scripts `__. You can also reference the `unit tests `__ for Tuple and Dict spaces, which show how to access nested observation fields. + +Custom Recurrent Models +~~~~~~~~~~~~~~~~~~~~~~~ + +Instead of using the ``use_lstm: True`` option, it can be preferable use a custom recurrent model. This provides more control over postprocessing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. The only difference from a normal custom model is that you have to define ``self.state_init``, ``self.state_in``, and ``self.state_out``. You can refer to the existing `lstm.py `__ model as an example to implement your own model: + +.. code-block:: python + + class MyCustomLSTM(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + # Some initial layers to process inputs, shape [BATCH, OBS...]. + features = some_hidden_layers(input_dict["obs"]) + + # Add back the nested time dimension for tf.dynamic_rnn, new shape + # will be [BATCH, MAX_SEQ_LEN, OBS...]. + last_layer = add_time_dimension(features, self.seq_lens) + + # Setup the LSTM cell (see lstm.py for an example) + lstm = rnn.BasicLSTMCell(256, state_is_tuple=True) + self.state_init = ... + self.state_in = ... 
+ lstm_out, lstm_state = tf.nn.dynamic_rnn( + lstm, + last_layer, + initial_state=..., + sequence_length=self.seq_lens, + time_major=False, + dtype=tf.float32) + self.state_out = list(lstm_state) + + # Drop the time dimension again so back to shape [BATCH, OBS...]. + # Note that we retain the zero padding (see issue #2992). + last_layer = tf.reshape(lstm_out, [-1, cell_size]) + logits = linear(last_layer, num_outputs, "action", + normc_initializer(0.01)) + return logits, last_layer + +Batch Normalization +~~~~~~~~~~~~~~~~~~~ + +You can use ``tf.layers.batch_normalization(x, training=input_dict["is_training"])`` to add batch norm layers to your custom model: `code example `__. RLlib will automatically run the update ops for the batch norm layers during optimization (see `tf_policy_graph.py `__ and `multi_gpu_impl.py `__ for the exact handling of these updates). Custom Preprocessors -------------------- -Similarly, custom preprocessors should subclass the RLlib `preprocessor class `__ and be registered in the model catalog: +Similarly, custom preprocessors should subclass the RLlib `preprocessor class `__ and be registered in the model catalog. Note that you can alternatively use `gym wrapper classes `__ around your environment instead of preprocessors. .. code-block:: python @@ -60,8 +164,8 @@ Similarly, custom preprocessors should subclass the RLlib `preprocessor class `__ and `Horizon `__. The general idea is that the meaning of actions can be completely conditioned on the observation, i.e., the ``a`` in ``Q(s, a)`` becomes just a token in ``[0, MAX_AVAIL_ACTIONS)`` that only has meaning in the context of ``s``. This works with algorithms in the `DQN and policy-gradient families `__ and can be implemented as follows: + +1. The environment should return a mask and/or list of valid action embeddings as part of the observation for each step. To enable batching, the number of actions can be allowed to vary from 1 to some max number: + +.. code-block:: python + + class MyParamActionEnv(gym.Env): + def __init__(self, max_avail_actions): + self.action_space = Discrete(max_avail_actions) + self.observation_space = Dict({ + "action_mask": Box(0, 1, shape=(max_avail_actions, )), + "avail_actions": Box(-1, 1, shape=(max_avail_actions, action_embedding_sz)), + "real_obs": ..., + }) + +2. A custom model can be defined that can interpret the ``action_mask`` and ``avail_actions`` portions of the observation. Here the model computes the action logits via the dot product of some network output and each action embedding. Invalid actions can be masked out of the softmax by scaling the probability to zero: + +.. code-block:: python + + class MyParamActionModel(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + avail_actions = input_dict["obs"]["avail_actions"] + action_mask = input_dict["obs"]["action_mask"] + + output = FullyConnectedNetwork( + input_dict["obs"]["real_obs"], num_outputs=action_embedding_sz) + + # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the + # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE]. + intent_vector = tf.expand_dims(output, 1) + + # Shape of logits is [BATCH, MAX_ACTIONS]. 
+ action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2) + + # Mask out invalid actions (use tf.float32.min for stability) + inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min) + masked_logits = inf_mask + action_logits + + return masked_logits, last_layer + + +Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py `__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN and several policy gradient algorithms. + Model-Based Rollouts -------------------- @@ -137,7 +288,8 @@ With a custom policy graph, you can also perform model-based rollouts and option def compute_actions(self, obs_batch, state_batches, - is_training=False, + prev_action_batch=None, + prev_reward_batch=None, episodes=None): # compute a batch of actions based on the current obs_batch # and state of each episode (i.e., for multiagent). You can do diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 25cd0d8931850..dc350d272c99e 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -10,11 +10,11 @@ be trained, checkpointed, or an action computed. .. image:: rllib-api.svg -You can train a simple DQN agent with the following command +You can train a simple DQN agent with the following command: .. code-block:: bash - python ray/python/ray/rllib/train.py --run DQN --env CartPole-v0 + rllib train --run DQN --env CartPole-v0 By default, the results will be logged to a subdirectory of ``~/ray_results``. This subdirectory will contain a file ``params.json`` which contains the @@ -26,10 +26,12 @@ training process with TensorBoard by running tensorboard --logdir=~/ray_results -The ``train.py`` script has a number of options you can show by running +The ``rllib train`` command (same as the ``train.py`` script in the repo) has a number of options you can show by running: .. code-block:: bash + rllib train --help + -or- python ray/python/ray/rllib/train.py --help The most important options are for choosing the environment @@ -37,46 +39,57 @@ with ``--env`` (any OpenAI gym environment including ones registered by the user can be used) and for choosing the algorithm with ``--run`` (available options are ``PPO``, ``PG``, ``A2C``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``APEX``, and ``APEX_DDPG``). +Evaluating Trained Agents +~~~~~~~~~~~~~~~~~~~~~~~~~ + +In order to save checkpoints from which to evaluate agents, +set ``--checkpoint-freq`` (number of training iterations between checkpoints) +when running ``rllib train``. + + +An example of evaluating a previously trained DQN agent is as follows: + +.. code-block:: bash + + rllib rollout \ + ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \ + --run DQN --env CartPole-v0 --steps 10000 + +The ``rollout.py`` helper script reconstructs a DQN agent from the checkpoint +located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1`` +and renders its behavior in the environment specified by ``--env``. 
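Putting the two commands together, a minimal train-then-evaluate workflow could look as follows (a sketch; the trial directory name under ``~/ray_results`` will differ on your machine):

.. code-block:: bash

    # Train, writing a checkpoint after every training iteration.
    rllib train --run DQN --env CartPole-v0 --checkpoint-freq 1

    # Replay the first saved checkpoint for 10000 steps.
    rllib rollout \
        ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \
        --run DQN --env CartPole-v0 --steps 10000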
+ +Configuration +------------- + Specifying Parameters ~~~~~~~~~~~~~~~~~~~~~ Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of `common hyperparameters `__. See the `algorithms documentation `__ for more information. -In an example below, we train A2C by specifying 8 workers through the config flag. We also set ``"monitor": true`` to save episode videos to the result dir: +In an example below, we train A2C by specifying 8 workers through the config flag. .. code-block:: bash - python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \ - --run=A2C --config '{"num_workers": 8, "monitor": true}' - -.. image:: rllib-config.svg + rllib train --env=PongDeterministic-v4 --run=A2C --config '{"num_workers": 8}' Specifying Resources ~~~~~~~~~~~~~~~~~~~~ -You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. Many agents also provide a ``num_gpus`` or ``gpu`` option. In addition, you can allocate a fraction of a GPU by setting ``gpu_fraction: f``. For example, with DQN you can pack five agents onto one GPU by setting ``gpu_fraction: 0.2``. Note that fractional GPU support requires enabling the experimental Xray backend by setting the environment variable ``RAY_USE_XRAY=1``. ->>>>>>> 01b030bd57f014386aa5e4c67a2e069938528abb - -Evaluating Trained Agents -~~~~~~~~~~~~~~~~~~~~~~~~~ - -In order to save checkpoints from which to evaluate agents, -set ``--checkpoint-freq`` (number of training iterations between checkpoints) -when running ``train.py``. - +You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five agents onto one GPU by setting ``num_gpus: 0.2``. Note that in Ray < 0.6.0 fractional GPU support requires setting the environment variable ``RAY_USE_XRAY=1``. -An example of evaluating a previously trained DQN agent is as follows: +.. image:: rllib-config.svg -.. code-block:: bash +Common Parameters +~~~~~~~~~~~~~~~~~ - python ray/python/ray/rllib/rollout.py \ - ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint-1 \ - --run DQN --env CartPole-v0 +The following is a list of the common agent hyperparameters: -The ``rollout.py`` helper script reconstructs a DQN agent from the checkpoint -located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint-1`` -and renders its behavior in the environment specified by ``--env``. +.. literalinclude:: ../../python/ray/rllib/agents/agent.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ Tuned Examples ~~~~~~~~~~~~~~ @@ -86,16 +99,16 @@ Some good hyperparameters and settings are available in (some of them are tuned to run on GPUs). If you find better settings or tune an algorithm on a different domain, consider submitting a Pull Request! -You can run these with the ``train.py`` script as follows: +You can run these with the ``rllib train`` command as follows: .. 
code-block:: bash - python ray/python/ray/rllib/train.py -f /path/to/tuned/example.yaml + rllib train -f /path/to/tuned/example.yaml Python API ---------- -The Python API provides the needed flexibility for applying RLlib to new problems. You will need to use this API if you wish to use custom environments, preprocesors, or models with RLlib. +The Python API provides the needed flexibility for applying RLlib to new problems. You will need to use this API if you wish to use `custom environments, preprocessors, or models `__ with RLlib. Here is an example of the basic usage: @@ -155,7 +168,7 @@ Tune will schedule the trials to run in parallel on your Ray cluster: == Status == Using FIFO scheduling algorithm. Resources requested: 4/4 CPUs, 0/0 GPUs - Result logdir: /home/eric/ray_results/my_experiment + Result logdir: ~/ray_results/my_experiment PENDING trials: - PPO_CartPole-v0_2_sgd_stepsize=0.0001: PENDING RUNNING trials: @@ -184,11 +197,194 @@ You can also access just the "master" copy of the agent state through ``agent.lo agent.optimizer.foreach_evaluator_with_index( lambda ev, i: ev.for_policy(lambda p: p.get_weights())) +Global Coordination +~~~~~~~~~~~~~~~~~~~ +Sometimes, it is necessary to coordinate between pieces of code that live in different processes managed by RLlib. For example, it can be useful to maintain a global average of a certain variable, or centrally control a hyperparameter used by policies. Ray provides a general way to achieve this through *named actors* (learn more about Ray actors `here `__). As an example, consider maintaining a shared global counter that is incremented by environments and read periodically from your driver program: + +.. code-block:: python + + from ray.experimental import named_actors + + @ray.remote + class Counter: + def __init__(self): + self.count = 0 + def inc(self, n): + self.count += n + def get(self): + return self.count + + # on the driver + counter = Counter.remote() + named_actors.register_actor("global_counter", counter) + print(ray.get(counter.get.remote())) # get the latest count + + # in your envs + counter = named_actors.get_actor("global_counter") + counter.inc.remote(1) # async call to increment the global count + +Ray actors provide high levels of performance, so in more complex cases they can be used implement communication patterns such as parameter servers and allreduce. + +Callbacks and Custom Metrics +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can provide callback functions to be called at points during policy evaluation. These functions have access to an info dict containing state for the current `episode `__. Custom state can be stored for the `episode `__ in the ``info["episode"].user_data`` dict, and custom scalar metrics reported by saving values to the ``info["episode"].custom_metrics`` dict. These custom metrics will be aggregated and reported as part of training results. The following example (full code `here `__) logs a custom metric from the environment: + +.. 
code-block:: python + + def on_episode_start(info): + print(info.keys()) # -> "env", 'episode" + episode = info["episode"] + print("episode {} started".format(episode.episode_id)) + episode.user_data["pole_angles"] = [] + + def on_episode_step(info): + episode = info["episode"] + pole_angle = abs(episode.last_observation_for()[2]) + episode.user_data["pole_angles"].append(pole_angle) + + def on_episode_end(info): + episode = info["episode"] + pole_angle = np.mean(episode.user_data["pole_angles"]) + print("episode {} ended with length {} and pole angles {}".format( + episode.episode_id, episode.length, pole_angle)) + episode.custom_metrics["pole_angle"] = pole_angle + + def on_train_result(info): + print("agent.train() result: {} -> {} episodes".format( + info["agent"].__name__, info["result"]["episodes_this_iter"])) + + ray.init() + trials = tune.run_experiments({ + "test": { + "env": "CartPole-v0", + "run": "PG", + "config": { + "callbacks": { + "on_episode_start": tune.function(on_episode_start), + "on_episode_step": tune.function(on_episode_step), + "on_episode_end": tune.function(on_episode_end), + "on_train_result": tune.function(on_train_result), + }, + }, + } + }) + +Custom metrics can be accessed and visualized like any other training result: + +.. image:: custom_metric.png + +Example: Curriculum Learning +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Let's look at two ways to use the above APIs to implement `curriculum learning `__. In curriculum learning, the agent task is adjusted over time to improve the learning process. Suppose that we have an environment class with a ``set_phase()`` method that we can call to adjust the task difficulty over time: + +Approach 1: Use the Agent API and update the environment between calls to ``train()``. This example shows the agent being run inside a Tune function: + +.. code-block:: python + + import ray + from ray import tune + from ray.rllib.agents.ppo import PPOAgent + + def train(config, reporter): + agent = PPOAgent(config=config, env=YourEnv) + while True: + result = agent.train() + reporter(**result) + if result["episode_reward_mean"] > 200: + phase = 2 + elif result["episode_reward_mean"] > 100: + phase = 1 + else: + phase = 0 + agent.optimizer.foreach_evaluator(lambda ev: ev.env.set_phase(phase)) + + ray.init() + tune.run_experiments({ + "curriculum": { + "run": train, + "config": { + "num_gpus": 0, + "num_workers": 2, + }, + "trial_resources": { + "cpu": 1, + "gpu": lambda spec: spec.config.num_gpus, + "extra_cpu": lambda spec: spec.config.num_workers, + }, + }, + }) + +Approach 2: Use the callbacks API to update the environment on new training results: + +.. code-block:: python + + import ray + from ray import tune + + def on_train_result(info): + result = info["result"] + if result["episode_reward_mean"] > 200: + phase = 2 + elif result["episode_reward_mean"] > 100: + phase = 1 + else: + phase = 0 + agent = info["agent"] + agent.optimizer.foreach_evaluator(lambda ev: ev.env.set_phase(phase)) + + ray.init() + tune.run_experiments({ + "curriculum": { + "run": "PPO", + "env": YourEnv, + "config": { + "callbacks": { + "on_train_result": tune.function(on_train_result), + }, + }, + }, + }) + +Debugging +--------- + +Gym Monitor +~~~~~~~~~~~ + +The ``"monitor": true`` config can be used to save Gym episode videos to the result dir. For example: + +.. 
code-block:: bash + + rllib train --env=PongDeterministic-v4 \ + --run=A2C --config '{"num_workers": 2, "monitor": true}' + + # videos will be saved in the ~/ray_results/ dir, for example + openaigym.video.0.31401.video000000.meta.json + openaigym.video.0.31401.video000000.mp4 + openaigym.video.0.31403.video000000.meta.json + openaigym.video.0.31403.video000000.mp4 + +Log Verbosity +~~~~~~~~~~~~~ + +You can control the agent log level via the ``"log_level"`` flag. Valid values are "INFO" (default), "DEBUG", "WARN", and "ERROR". This can be used to increase or decrease the verbosity of internal logging. For example: + +.. code-block:: bash + + rllib train --env=PongDeterministic-v4 \ + --run=A2C --config '{"num_workers": 2, "log_level": "DEBUG"}' + +Stack Traces +~~~~~~~~~~~~ + +You can use the ``ray stack`` command to dump the stack traces of all the Python workers on a single node. This can be useful for debugging unexpected hangs or performance issues. REST API -------- -In some cases (i.e., when interacting with an external environment) it makes more sense to interact with RLlib as if were an independently running service, rather than RLlib hosting the simulations itself. This is possible via RLlib's serving env `interface `__. +In some cases (i.e., when interacting with an externally hosted simulator or production environment) it makes more sense to interact with RLlib as if were an independently running service, rather than RLlib hosting the simulations itself. This is possible via RLlib's external agents `interface `__. .. autoclass:: ray.rllib.utils.policy_client.PolicyClient :members: diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index ea5bbbf583810..e96bd6fccbcb9 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -10,14 +10,14 @@ Learn more about RLlib's design by reading the `ICML paper `__ or `TensorFlow `__. Then, install the Ray RLlib module: +RLlib has extra dependencies on top of ``ray``. First, you'll need to install either `PyTorch `__ or `TensorFlow `__. Then, install the RLlib module: .. code-block:: bash pip install tensorflow # or tensorflow-gpu - pip install ray[rllib] + pip install ray[rllib] # also recommended: ray[debug] -You might also want to clone the Ray repo for convenient access to RLlib helper scripts: +You might also want to clone the `Ray repo `__ for convenient access to RLlib helper scripts: .. code-block:: bash @@ -27,7 +27,9 @@ You might also want to clone the Ray repo for convenient access to RLlib helper Training APIs ------------- * `Command-line `__ +* `Configuration `__ * `Python API `__ +* `Debugging `__ * `REST API `__ Environments @@ -36,8 +38,7 @@ Environments * `OpenAI Gym `__ * `Vectorized `__ * `Multi-Agent `__ -* `Agent-Driven `__ -* `Offline Data Ingest `__ +* `Interfacing with External Agents `__ * `Batch Asynchronous `__ Algorithms @@ -53,9 +54,9 @@ Algorithms - `Advantage Actor-Critic (A2C, A3C) `__ - - `Deep Deterministic Policy Gradients (DDPG) `__ + - `Deep Deterministic Policy Gradients (DDPG, TD3) `__ - - `Deep Q Networks (DQN, Rainbow) `__ + - `Deep Q Networks (DQN, Rainbow, Parametric DQN) `__ - `Policy Gradients `__ @@ -74,6 +75,7 @@ Models and Preprocessors * `Custom Models `__ * `Custom Preprocessors `__ * `Customizing Policy Graphs `__ +* `Variable-length / Parametric Action Spaces `__ * `Model-Based Rollouts `__ RLlib Concepts @@ -98,3 +100,6 @@ If you encounter errors like `blas_thread_init: pthread_create: Resource temporarily unavailable` when using many workers, try setting ``OMP_NUM_THREADS=1``. 
Similarly, check configured system limits with `ulimit -a` for other resource limit errors. + +For debugging unexpected hangs or performance problems, you can run ``ray stack`` to dump +the stack traces of all Ray workers on the current node. This requires py-spy to be installed. diff --git a/doc/source/security.rst b/doc/source/security.rst new file mode 100644 index 0000000000000..6b636c66858e2 --- /dev/null +++ b/doc/source/security.rst @@ -0,0 +1,55 @@ +Security +======== + +This document describes best security practices for using Ray. + +Intended Use and Threat Model +----------------------------- + +Ray instances should run on a secure network without public facing ports. +The most common threat for Ray instances is unauthorized access to Redis, +which can be exploited to gain shell access and run arbitrary code. +The best fix is to run Ray instances on a secure, trusted network. + +Running Ray on a secure network is not always feasible, so Ray +provides some basic security features: + + +Redis Port Authentication +------------------------- + +To prevent exploits via unauthorized Redis access, Ray provides the option to +password-protect Redis ports. While this is not a replacement for running Ray +behind a firewall, this feature is useful for instances exposed to the internet +where configuring a firewall is not possible. Because Redis is +very fast at serving queries, attackers can attempt passwords at a high rate, so the chosen password should be long. + +Redis authentication is only supported on the raylet code path. + +To add authentication via the Python API, start Ray using: + +.. code-block:: python + + ray.init(redis_password="password") + +To add authentication via the CLI, or connect to an existing Ray instance with +password-protected Redis ports: + +.. code-block:: bash + + ray start [--head] --redis-password="password" + +While Redis port authentication may protect against external attackers, +Ray does not encrypt traffic between nodes, so man-in-the-middle attacks are +possible for clusters on untrusted networks. + +Cloud Security +-------------- + +Launching Ray clusters on AWS or GCP using the ``ray up`` command +automatically configures security groups that prevent external Redis access. + +References +---------- + +- The `Redis security documentation ` diff --git a/doc/source/sgd.png b/doc/source/sgd.png new file mode 100644 index 0000000000000..aed38161cb159 Binary files /dev/null and b/doc/source/sgd.png differ diff --git a/doc/source/tempfile.rst b/doc/source/tempfile.rst new file mode 100644 index 0000000000000..d68e835e0261b --- /dev/null +++ b/doc/source/tempfile.rst @@ -0,0 +1,86 @@ +Temporary Files +=============== + +Ray produces a number of temporary files while it is running. +They are useful for logging, debugging, and sharing the object store with other programs. + +Location of Temporary Files +--------------------------- + +First, we introduce the concept of a Ray session. + +A session contains a set of processes. A session is created by running the +``ray start`` command or calling ``ray.init()`` in a Python script, and it ends when you +run ``ray stop`` or call ``ray.shutdown()``. + +For each session, Ray will create a *root temporary directory* to place all its +temporary files. The path is ``/tmp/ray/session_{datetime}_{pid}`` by default. +The pid is that of the startup process (the process calling ``ray.init()`` or +the Ray process executed by a shell in ``ray start``). +You can sort the session directories by name to find the latest one.
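As a concrete illustration, here is a small sketch (assuming the default ``/tmp/ray`` root and at least one session on the machine) of how the most recent session and its logs can be located programmatically:

.. code-block:: python

    import glob
    import os

    # Session directories are named session_{datetime}_{pid}, so sorting the
    # names lexicographically also sorts them chronologically.
    sessions = sorted(glob.glob("/tmp/ray/session_*"))
    latest = sessions[-1]

    print("Latest session:", latest)
    print("Log files:", os.listdir(os.path.join(latest, "logs")))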
+ +You can change the *root temporary directory* in one of these ways: + +* Pass ``--temp-dir={your temp path}`` to ``ray start`` +* Specify ``temp_dir`` when calling ``ray.init()`` + +You can also use ``default_worker.py --temp-dir={your temp path}`` to +start a new worker with the given *root temporary directory*. + +The *root temporary directory* you specify will be used as it is, +without a pid or datetime suffix attached. + +Layout of Temporary Files +------------------------- + +A typical layout of temporary files could look like this: + +.. code-block:: text + + /tmp + └── ray + └── session_{datetime}_{pid} + ├── logs # for logging + │   ├── log_monitor.err + │   ├── log_monitor.out + │   ├── monitor.err + │   ├── monitor.out + │   ├── plasma_store_0.err # array of plasma stores' outputs + │   ├── plasma_store_0.out + │   ├── raylet_0.err # array of raylets' outputs. Control it with `--no-redirect-worker-output` (in Ray's command line) or `redirect_worker_output` (in ray.init()) + │   ├── raylet_0.out + │   ├── redis-shard_0.err # array of redis shards' outputs + │   ├── redis-shard_0.out + │   ├── redis.err # redis + │   ├── redis.out + │   ├── webui.err # ipython notebook web ui + │   ├── webui.out + │   ├── worker-{worker_id}.err # redirected output of workers + │   ├── worker-{worker_id}.out + │   └── {other workers} + ├── ray_ui.ipynb # ipython notebook file + └── sockets # for sockets + ├── plasma_store + └── raylet # this could be deleted by Ray's shutdown cleanup. + + +Plasma Object Store Socket +-------------------------- + +Plasma object store sockets can be used to share objects with other programs using Apache Arrow. + +You can specify the plasma object store socket in one of these ways: + +* Pass ``--plasma-store-socket-name={your socket path}`` to ``ray start`` +* Specify ``plasma_store_socket_name`` when calling ``ray.init()`` + +The path you specify will be used as it is and will not be modified by Ray. + + +Notes +----- + +Temporary file policies are defined in ``python/ray/tempfile_services.py``. + +Currently, ``/tmp/ray`` is still used as the default directory for RLlib's temporary data files, as before. +This is not ideal and may change in a later release. diff --git a/doc/source/troubleshooting.rst b/doc/source/troubleshooting.rst index ff4b3039e8c15..86f56e7755957 100644 --- a/doc/source/troubleshooting.rst +++ b/doc/source/troubleshooting.rst @@ -61,10 +61,10 @@ of the following reasons. - **Stressful workloads:** Workloads that create many many tasks in a short amount of time can sometimes interfere with the heartbeat mechanism that we use to check that processes are still alive. On the head node in the cluster, - you can check the files ``/tmp/raylogs/monitor-******.out`` and - ``/tmp/raylogs/monitor-******.err``. They will indicate which processes Ray - has marked as dead (due to a lack of heartbeats). However, it is currently - possible for a process to get marked as dead without actually having died. + you can check the files ``/tmp/ray/session_*/logs/monitor*``. They will + indicate which processes Ray has marked as dead (due to a lack of heartbeats). + However, it is currently possible for a process to get marked as dead without + actually having died. - **Starting many actors:** Workloads that start a large number of actors all at once may exhibit problems when the processes (or libraries that they use) @@ -92,6 +92,11 @@ of the following reasons. Hanging ------- +..
tip:: + + You can run ``ray stack`` to dump the stack traces of all Ray workers on + the current node. This requires py-spy to be installed. + If a workload is hanging and not progressing, the problem may be one of the following. diff --git a/doc/source/tune-examples.rst b/doc/source/tune-examples.rst new file mode 100644 index 0000000000000..e0af86bcb6956 --- /dev/null +++ b/doc/source/tune-examples.rst @@ -0,0 +1,62 @@ +Tune Examples +============= + +.. Keep this in sync with ray/python/ray/tune/examples/README.rst + +In our repository, we provide a variety of examples for the various use cases and features of Tune. + +If any example is broken, or if you'd like to add an example to this page, feel free to raise an issue on our Github repository. + + +General Examples +---------------- + +- `async_hyperband_example `__: + Example of using a Trainable class with AsyncHyperBandScheduler. +- `hyperband_example `__: + Example of using a Trainable class with HyperBandScheduler. Also uses the Experiment class API for specifying the experiment configuration. +- `hyperopt_example `__: + Optimizes a basic function using the function-based API and the HyperOptSearch (SearchAlgorithm wrapper for HyperOpt TPE). + Also uses the AsyncHyperBandScheduler. +- `pbt_example `__: + Example of using a Trainable class with PopulationBasedTraining scheduler. +- `pbt_ppo_example `__: + Example of optimizing a distributed RLlib algorithm (PPO) with the PopulationBasedTraining scheduler. + + +Keras Examples +-------------- + +- `tune_mnist_keras `__: + Converts the Keras MNIST example to use Tune with the function-based API and a Keras callback. Also shows how to easily convert something relying on argparse to use Tune. + + +PyTorch Examples +---------------- + +- `mnist_pytorch `__: + Converts the PyTorch MNIST example to use Tune with the function-based API. Also shows how to easily convert something relying on argparse to use Tune. +- `mnist_pytorch_trainable `__: + Converts the PyTorch MNIST example to use Tune with Trainable API. Also uses the HyperBandScheduler and checkpoints the model at the end. + + +TensorFlow Examples +------------------- + +- `tune_mnist_ray `__: + A basic example of tuning a TensorFlow model on MNIST using the Trainable class. +- `tune_mnist_ray_hyperband `__: + A basic example of tuning a TensorFlow model on MNIST using the Trainable class and the HyperBand scheduler. +- `tune_mnist_async_hyperband `__: + Example of tuning a TensorFlow model on MNIST using AsyncHyperBand. + + +Contributed Examples +-------------------- + +- `pbt_tune_cifar10_with_keras `__: + A contributed example of tuning a Keras model on CIFAR10 with the PopulationBasedTraining scheduler. +- `genetic_example `__: + Optimizing the michalewicz function using the contributed GeneticSearch search algorithm with AsyncHyperBandScheduler. + + diff --git a/doc/source/tune-package-ref.rst b/doc/source/tune-package-ref.rst index d6f13cd981556..e7f3d3167adab 100644 --- a/doc/source/tune-package-ref.rst +++ b/doc/source/tune-package-ref.rst @@ -12,6 +12,11 @@ ray.tune :members: :private-members: + +.. autoclass:: ray.tune.function_runner.StatusReporter + :members: __call__ + + ray.tune.schedulers ------------------- @@ -24,7 +29,7 @@ ray.tune.suggest .. automodule:: ray.tune.suggest :members: - :exclude-members: function, grid_search, SuggestionAlgorithm + :exclude-members: function, sample_from, grid_search, SuggestionAlgorithm :show-inheritance: .. 
autoclass:: ray.tune.suggest.SuggestionAlgorithm diff --git a/doc/source/tune-searchalg.rst b/doc/source/tune-searchalg.rst index 97e8ce1bc295c..e8e5b0fa672ef 100644 --- a/doc/source/tune-searchalg.rst +++ b/doc/source/tune-searchalg.rst @@ -25,10 +25,13 @@ By default, Tune uses the `default search space and variant generation process < :noindex: +Note that other search algorithms will not necessarily extend this class and may require a different search space declaration than the default Tune format. + HyperOpt Search (Tree-structured Parzen Estimators) --------------------------------------------------- -The ``HyperOptSearch`` is a SearchAlgorithm that is backed by `HyperOpt `__ to perform sequential model-based hyperparameter optimization. +The ``HyperOptSearch`` is a SearchAlgorithm that is backed by `HyperOpt `__ to perform sequential model-based hyperparameter optimization. Note that this class does not extend ``ray.tune.suggest.BasicVariantGenerator``, so you will not be able to use Tune's default variant generation/search space declaration when using HyperOptSearch. + In order to use this search algorithm, you will need to install HyperOpt via the following command: .. code-block:: bash @@ -47,7 +50,6 @@ An example of this can be found in `hyperopt_example.py `__. - -More information about Tune's `trial schedulers can be found here `__. - +More information about Tune's `search algorithms can be found here `__. More information about Tune's `trial schedulers can be found here `__. Start by installing, importing, and initializing Ray. @@ -22,29 +19,48 @@ Start by installing, importing, and initializing Ray. ray.init() -Tune provides a ``run_experiments`` function that generates and runs the trials as described by the `experiment specification `__. -.. autofunction:: ray.tune.run_experiments - :noindex: +Experiment Configuration +------------------------ -This function will report status on the command line until all Trials stop: +This section will cover the main steps needed to modify your code to run Tune: using the `Training API `__ and `executing your Tune experiment `__. -:: +You can checkout out our `examples page `__ for more code examples. - == Status == - Using FIFO scheduling algorithm. - Resources used: 4/8 CPUs, 0/0 GPUs - Result logdir: ~/ray_results/my_experiment - - train_func_0_lr=0.2,momentum=1: RUNNING [pid=6778], 209 s, 20604 ts, 7.29 acc - - train_func_1_lr=0.4,momentum=1: RUNNING [pid=6780], 208 s, 20522 ts, 53.1 acc - - train_func_2_lr=0.6,momentum=1: TERMINATED [pid=6789], 21 s, 2190 ts, 100 acc - - train_func_3_lr=0.2,momentum=2: RUNNING [pid=6791], 208 s, 41004 ts, 8.37 acc - - train_func_4_lr=0.4,momentum=2: RUNNING [pid=6800], 209 s, 41204 ts, 70.1 acc - - train_func_5_lr=0.6,momentum=2: TERMINATED [pid=6809], 10 s, 2164 ts, 100 acc +Training API +~~~~~~~~~~~~ +Training can be done with either the **function-based API** or **Trainable API**. + +**Python functions** will need to have the following signature: + +.. code-block:: python + + def trainable(config, reporter): + """ + Args: + config (dict): Parameters provided from the search algorithm + or variant generation. + reporter (Reporter): Handle to report intermediate metrics to Tune. + """ + + while True: + # ... + reporter(**kwargs) + +The reporter will allow you to report metrics used for scheduling, search, or early stopping. + +Tune will run this function on a separate thread in a Ray actor process. 
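For instance, a minimal sketch of a function-based trainable might look like the following (the ``"alpha"`` hyperparameter and the update rule are made up purely for illustration):

.. code-block:: python

    import random

    from ray import tune

    def my_trainable(config, reporter):
        # "alpha" is a hypothetical hyperparameter passed in via the config dict.
        alpha = config.get("alpha", 0.1)
        loss = 1.0
        for step in range(100):
            loss -= alpha * random.random() * loss  # stand-in for a real training step
            # Report intermediate metrics; these can drive trial schedulers,
            # search algorithms, and stopping conditions.
            reporter(timesteps_this_iter=1, mean_loss=loss)

    tune.register_trainable("my_trainable", my_trainable)

An experiment can then refer to the registered name, e.g. ``"run": "my_trainable"``, in its specification as described below.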
Note that this API is not checkpointable, since the thread will never return control back to its caller. The reporter documentation can be `found here `__. + +.. note:: + If you have a lambda function that you want to train, you will need to first register the function: ``tune.register_trainable("lambda_id", lambda x: ...)``. You can then use ``lambda_id`` in place of ``my_trainable``. + +**Python classes** passed into Tune will need to subclass ``ray.tune.Trainable``. The Trainable interface `can be found here `__. + +Both the Trainable and function-based API will have `autofilled metrics `__ in addition to the metrics reported. + +See the `experiment specification `__ section on how to specify and execute your training. -Experiment Configuration ------------------------- Specifying Experiments ~~~~~~~~~~~~~~~~~~~~~~ @@ -79,54 +95,33 @@ dictionary. Tune will convert the dict into an ``ray.tune.Experiment`` object. "max_failures": 2 } } - run_experiments(experiment_spec) - - -An example of this can be found in `async_hyperband_example.py `__. - -Model API -~~~~~~~~~ - -You can either pass in a Python function or Python class for model training as follows, each requiring a specific signature/interface: - -.. code-block:: python - :emphasize-lines: 3,8 - - experiment_spec = { - "my_experiment_name": { - "run": my_trainable - } - } - # or with the Experiment API - experiment_spec = Experiment("my_experiment_name", my_trainable) +Tune provides a ``run_experiments`` function that generates and runs the trials. - run_experiments(experiments=experiment_spec) - - -**Python functions** will need to have the following signature: +.. autofunction:: ray.tune.run_experiments + :noindex: -.. code-block:: python +This function will report status on the command line until all Trials stop: - def trainable(config, reporter): - """ - Args: - config (dict): Parameters provided from the search algorithm - or variant generation. - reporter (Reporter): Handle to report intermediate metrics to Tune. - """ +:: -Tune will run this function on a separate thread in a Ray actor process. Note that trainable functions are not checkpointable, since they never return control back to their caller. See `Trial Checkpointing for more details `__. + == Status == + Using FIFO scheduling algorithm. + Resources used: 4/8 CPUs, 0/0 GPUs + Result logdir: ~/ray_results/my_experiment + - train_func_0_lr=0.2,momentum=1: RUNNING [pid=6778], 209 s, 20604 ts, 7.29 acc + - train_func_1_lr=0.4,momentum=1: RUNNING [pid=6780], 208 s, 20522 ts, 53.1 acc + - train_func_2_lr=0.6,momentum=1: TERMINATED [pid=6789], 21 s, 2190 ts, 100 acc + - train_func_3_lr=0.2,momentum=2: RUNNING [pid=6791], 208 s, 41004 ts, 8.37 acc + - train_func_4_lr=0.4,momentum=2: RUNNING [pid=6800], 209 s, 41204 ts, 70.1 acc + - train_func_5_lr=0.6,momentum=2: TERMINATED [pid=6809], 10 s, 2164 ts, 100 acc -.. note:: - If you have a lambda function that you want to train, you will need to first register the function: ``tune.register_trainable("lambda_id", lambda x: ...)``. You can then use ``lambda_id`` in place of ``my_trainable``. -**Python classes** passed into Tune will need to subclass ``ray.tune.Trainable``. +An example of this can be found in `async_hyperband_example.py `__. -.. 
autoclass:: ray.tune.Trainable - :members: __init__, _save, _restore, _train, _setup, _stop - :noindex: +Training Features +----------------- Tune Search Space (Default) ~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -134,6 +129,9 @@ Tune Search Space (Default) You can use ``tune.grid_search`` to specify an axis of a grid search. By default, Tune also supports sampling parameters from user-specified lambda functions, which can be used independently or in combination with grid search. +.. note:: + If you specify an explicit Search Algorithm such as any SuggestionAlgorithm, you may not be able to specify lambdas or grid search with this interface, as the search algorithm may require a different search space declaration. + The following shows grid search over two nested parameters combined with random sampling from two lambda functions, generating 9 different trials. Note that the value of ``beta`` depends on the value of ``alpha``, which is represented by referencing ``spec.config.alpha`` in the lambda function. This lets you specify conditional parameter distributions. .. code-block:: python @@ -143,8 +141,8 @@ The following shows grid search over two nested parameters combined with random "my_experiment_name": { "run": my_trainable, "config": { - "alpha": lambda spec: np.random.uniform(100), - "beta": lambda spec: spec.config.alpha * np.random.normal(), + "alpha": tune.sample_from(lambda spec: np.random.uniform(100)), + "beta": tune.sample_from(lambda spec: spec.config.alpha * np.random.normal()), "nn_layers": [ tune.grid_search([16, 64, 256]), tune.grid_search([16, 64, 256]), @@ -155,10 +153,7 @@ The following shows grid search over two nested parameters combined with random .. note:: - Lambda functions will be evaluated during trial variant generation. If you need to pass a literal function in your config, use ``tune.function(...)`` to escape it. - -.. warning:: - If you specify a Search Algorithm, you may not be able to use this feature, as the algorithm may require a different search space declaration. + Use ``tune.sample_from(...)`` to sample from a function during trial variant generation. If you need to pass a literal function in your config, use ``tune.function(...)`` to escape it. For more information on variant generation, see `basic_variant.py `__. @@ -174,8 +169,8 @@ By default, each random variable and grid search point is sampled once. To take "my_experiment_name": { "run": my_trainable, "config": { - "alpha": lambda spec: np.random.uniform(100), - "beta": lambda spec: spec.config.alpha * np.random.normal(), + "alpha": tune.sample_from(lambda spec: np.random.uniform(100)), + "beta": tune.sample_from(lambda spec: spec.config.alpha * np.random.normal()), "nn_layers": [ tune.grid_search([16, 64, 256]), tune.grid_search([16, 64, 256]), @@ -193,9 +188,12 @@ Using GPUs (Resource Allocation) Tune will allocate the specified GPU and CPU ``trial_resources`` to each individual trial (defaulting to 1 CPU per trial). Under the hood, Tune runs each trial as a Ray actor, using Ray's resource handling to allocate resources and place actors. A trial will not be scheduled unless at least that amount of resources is available in the cluster, preventing the cluster from being overloaded. +Fractional values are also supported, (i.e., ``"gpu": 0.2``). You can find an example of this in the `Keras MNIST example `__. + If GPU resources are not requested, the ``CUDA_VISIBLE_DEVICES`` environment variable will be set as empty, disallowing GPU access. 
Otherwise, it will be set to the GPUs in the list (this is managed by Ray). + If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will also want to set ``extra_cpu`` or ``extra_gpu`` to reserve extra resource slots for the actors you will create. For example, if a trainable class requires 1 GPU itself, but will launch 4 actors each using another GPU, then it should set ``"gpu": 1, "extra_gpu": 4``. .. code-block:: python @@ -216,14 +214,14 @@ If your trainable function / class creates further Ray actors or tasks that also Trial Checkpointing ~~~~~~~~~~~~~~~~~~~ -To enable checkpointing, you must implement a `Trainable class `__ (Trainable functions are not checkpointable, since they never return control back to their caller). The easiest way to do this is to subclass the pre-defined ``Trainable`` class and implement its ``_train``, ``_save``, and ``_restore`` abstract methods `(example) `__. Implementing this interface is required to support resource multiplexing in Trial Schedulers such as HyperBand and PBT. +To enable checkpointing, you must implement a `Trainable class `__ (Trainable functions are not checkpointable, since they never return control back to their caller). The easiest way to do this is to subclass the pre-defined ``Trainable`` class and implement its ``_train``, ``_save``, and ``_restore`` abstract methods `(example) `__. Implementing this interface is required to support resource multiplexing in Trial Schedulers such as HyperBand and PBT. For TensorFlow model training, this would look something like this `(full tensorflow example) `__: .. code-block:: python class MyClass(Trainable): - def _setup(self): + def _setup(self, config): self.saver = tf.train.Saver() self.sess = ... self.iteration = 0 @@ -297,6 +295,28 @@ You often will want to compute a large object (e.g., training data, model weight } }) +Auto-Filled Results +------------------- + +During training, Tune will automatically fill certain fields if not already provided. All of these can be used as stopping conditions or in the Scheduler/Search Algorithm specification. + +.. literalinclude:: ../../python/ray/tune/result.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + +The following fields will automatically show up on the console output, if provided: + +1. ``episode_reward_mean`` +2. ``mean_loss`` +3. ``mean_accuracy`` +4. ``timesteps_this_iter`` (aggregated into ``timesteps_total``). + +.. code-block:: bash + + Example_0: TERMINATED [pid=68248], 179 s, 2 iter, 60000 ts, 94 rew + + Logging and Visualizing Results ------------------------------- @@ -360,12 +380,6 @@ Then, on the client side, you can use the following class. The server address de For an example notebook for using the Client API, see the `Client API Example `__. -Examples --------- - -You can find a comprehensive of examples `using Tune and its various features here `__, including examples using Keras, TensorFlow, and Population-Based Training. - - Further Questions or Issues? ---------------------------- diff --git a/doc/source/tune.rst b/doc/source/tune.rst index a849f3b811d58..14c95fb0edcb0 100644 --- a/doc/source/tune.rst +++ b/doc/source/tune.rst @@ -7,7 +7,10 @@ Tune: Scalable Hyperparameter Search Tune is a scalable framework for hyperparameter search with a focus on deep learning and deep reinforcement learning. -You can find the code for Tune `here on GitHub `__. +You can find the code for Tune `here on GitHub `__. 
To get started with Tune, try going through `our tutorial of using Tune with Keras `__. + +(Experimental): You can try out `the above tutorial on a free hosted server via Binder `__. + Features -------- @@ -42,7 +45,7 @@ You'll need to first `install ray `__ to import Tune. .. code-block:: bash - pip install ray + pip install ray # also recommended: ray[debug] Quick Start diff --git a/doc/source/tutorial.rst b/doc/source/tutorial.rst index 0493b69169909..81de87a571ced 100644 --- a/doc/source/tutorial.rst +++ b/doc/source/tutorial.rst @@ -9,7 +9,7 @@ To use Ray, you need to understand the following: Overview -------- -Ray is a Python-based distributed execution engine. The same code can be run on +Ray is a distributed execution engine. The same code can be run on a single machine to achieve efficient multiprocessing, and it can be used on a cluster for large computations. @@ -21,8 +21,6 @@ When using Ray, several processes are involved. allows workers to efficiently share objects on the same node with minimal copying and deserialization. - One **local scheduler** per node assigns tasks to workers on the same node. -- A **global scheduler** receives tasks from local schedulers and assigns them - to other local schedulers. - A **driver** is the Python process that the user controls. For example, if the user is running a script or using a Python shell, then the driver is the Python process that runs the script or the shell. A driver is similar to a worker in diff --git a/doc/source/user-profiling.rst b/doc/source/user-profiling.rst index e7c18dd5ee737..cdbabff391884 100644 --- a/doc/source/user-profiling.rst +++ b/doc/source/user-profiling.rst @@ -1,11 +1,11 @@ Profiling for Ray Users ======================= -This document is intended for users of Ray who want to know how to evaluate -the performance of their code while running on Ray. Profiling the -performance of your code can be very helpful to determine performance -bottlenecks or to find out where your code may not be parallelized properly. -If you are interested in pinpointing why your Ray application may not be +This document is intended for users of Ray who want to know how to evaluate +the performance of their code while running on Ray. Profiling the +performance of your code can be very helpful to determine performance +bottlenecks or to find out where your code may not be parallelized properly. +If you are interested in pinpointing why your Ray application may not be achieving the expected speedup, read on! @@ -28,26 +28,26 @@ let's define our remote function to just sleep for 0.5 seconds: def func(): time.sleep(0.5) -In our example setup, we wish to call our remote function ``func()`` five -times, and store the result of each call into a list. To compare the -performance of different ways of looping our calls to our remote function, +In our example setup, we wish to call our remote function ``func()`` five +times, and store the result of each call into a list. To compare the +performance of different ways of looping our calls to our remote function, we can define each loop version as a separate function on the driver script. -For the first version **ex1**, each iteration of the loop calls the remote -function, then calls ``ray.get`` in an attempt to store the current result +For the first version **ex1**, each iteration of the loop calls the remote +function, then calls ``ray.get`` in an attempt to store the current result into the list, as follows: .. 
code-block:: python # This loop is suboptimal in Ray, and should only be used for the sake of this example - def ex1(): + def ex1(): list1 = [] for i in range(5): list1.append(ray.get(func.remote())) -For the second version **ex2**, each iteration of the loop calls the remote -function, and stores it into the list **without** calling ``ray.get`` each time. -``ray.get`` is used after the loop has finished, in preparation for processing +For the second version **ex2**, each iteration of the loop calls the remote +function, and stores it into the list **without** calling ``ray.get`` each time. +``ray.get`` is used after the loop has finished, in preparation for processing ``func()``'s results: .. code-block:: python @@ -59,8 +59,8 @@ function, and stores it into the list **without** calling ``ray.get`` each time. list2.append(func.remote()) ray.get(list2) -Finally, for an example that's not so parallelizable, let's create a -third version **ex3** where the driver has to call a local +Finally, for an example that's not so parallelizable, let's create a +third version **ex3** where the driver has to call a local function in between each call to the remote function ``func()``: .. code-block:: python @@ -81,14 +81,14 @@ Timing Performance Using Python's Timestamps -------------------------------------------- One way to sanity-check the performance of the three loops is simply to -time how long it takes to complete each loop version. We can do this using +time how long it takes to complete each loop version. We can do this using python's built-in ``time`` `module`_. .. _`module`: https://docs.python.org/3/library/time.html -The ``time`` module contains a useful ``time()`` function that returns the -current timestamp in unix time whenever it's called. We can create a generic -function wrapper to call ``time()`` right before and right after each loop +The ``time`` module contains a useful ``time()`` function that returns the +current timestamp in unix time whenever it's called. We can create a generic +function wrapper to call ``time()`` right before and right after each loop function to print out how long each loop takes overall: .. code-block:: python @@ -106,8 +106,8 @@ function to print out how long each loop takes overall: return result return timed_wrapper -To always print out how long the loop takes to run each time the loop -function ``ex1()`` is called, we can evoke our ``time_this`` wrapper with +To always print out how long the loop takes to run each time the loop +function ``ex1()`` is called, we can evoke our ``time_this`` wrapper with a function decorator. This can similarly be done to functions ``ex2()`` and ``ex3()``: @@ -136,9 +136,9 @@ Then, running the three timed loops should yield output similar to this: | func:'ex2' args:[(), {}] took: 1.0032 seconds | | func:'ex3' args:[(), {}] took: 2.0039 seconds | -Let's interpret these results. +Let's interpret these results. -Here, ``ex1()`` took substantially more time than ``ex2()``, where +Here, ``ex1()`` took substantially more time than ``ex2()``, where their only difference is that ``ex1()`` calls ``ray.get`` on the remote function before adding it to the list, while ``ex2()`` waits to fetch the entire list with ``ray.get`` at once. @@ -160,28 +160,28 @@ entire list with ``ray.get`` at once. list2.append(func.remote()) ray.get(list2) -Notice how ``ex1()`` took 2.5 seconds, exactly five times 0.5 seconds, or -the time it would take to wait for our remote function five times in a row. 
+Notice how ``ex1()`` took 2.5 seconds, exactly five times 0.5 seconds, or +the time it would take to wait for our remote function five times in a row. -By calling ``ray.get`` after each call to the remote function, ``ex1()`` -removes all ability to parallelize work, by forcing the driver to wait for -each ``func()``'s result in succession. We are not taking advantage of Ray -parallelization here! +By calling ``ray.get`` after each call to the remote function, ``ex1()`` +removes all ability to parallelize work, by forcing the driver to wait for +each ``func()``'s result in succession. We are not taking advantage of Ray +parallelization here! -Meanwhile, ``ex2()`` takes about 1 second, much faster than it would normally -take to call ``func()`` five times iteratively. Ray is running each call to -``func()`` in parallel, saving us time. +Meanwhile, ``ex2()`` takes about 1 second, much faster than it would normally +take to call ``func()`` five times iteratively. Ray is running each call to +``func()`` in parallel, saving us time. -``ex1()`` is actually a common user mistake in Ray. ``ray.get`` is not -necessary to do before adding the result of ``func()`` to the list. Instead, -the driver should send out all parallelizable calls to the remote function +``ex1()`` is actually a common user mistake in Ray. ``ray.get`` is not +necessary to do before adding the result of ``func()`` to the list. Instead, +the driver should send out all parallelizable calls to the remote function to Ray before waiting to receive their results with ``ray.get``. ``ex1()``'s suboptimal behavior can be noticed just using this simple timing test. -Realistically, however, many applications are not as highly parallelizable -as ``ex2()``, and the application includes sections where the code must run in +Realistically, however, many applications are not as highly parallelizable +as ``ex2()``, and the application includes sections where the code must run in serial. ``ex3()`` is such an example, where the local function ``other_func()`` -must run first before each call to ``func()`` can be submitted to Ray. +must run first before each call to ``func()`` can be submitted to Ray. .. code-block:: python @@ -196,23 +196,23 @@ must run first before each call to ``func()`` can be submitted to Ray. list2.append(func.remote()) ray.get(list3) -What results is that while ``ex3()`` still gained 0.5 seconds of speedup +What results is that while ``ex3()`` still gained 0.5 seconds of speedup compared to the completely serialized ``ex1()`` version, this speedup is -still nowhere near the ideal speedup of ``ex2()``. +still nowhere near the ideal speedup of ``ex2()``. -The dramatic speedup of ``ex2()`` is possible because ``ex2()`` is -theoretically completely parallelizable: if we were given 5 CPUs, all 5 calls -to ``func()`` can be run in parallel. What is happening with ``ex3()``, -however, is that each parallelized call to ``func()`` is staggered by a wait +The dramatic speedup of ``ex2()`` is possible because ``ex2()`` is +theoretically completely parallelizable: if we were given 5 CPUs, all 5 calls +to ``func()`` can be run in parallel. What is happening with ``ex3()``, +however, is that each parallelized call to ``func()`` is staggered by a wait of 0.3 seconds for the local ``other_func()`` to finish. -``ex3()`` is thus a manifestation of `Amdahls Law`_: the fastest theoretically -possible execution time from parallelizing an application is limited to be -no better than the time it takes to run all serial parts in serial. 
+``ex3()`` is thus a manifestation of `Amdahls Law`_: the fastest theoretically +possible execution time from parallelizing an application is limited to be +no better than the time it takes to run all serial parts in serial. .. _`Amdahls Law`: https://en.wikipedia.org/wiki/Amdahl%27s_law -Due to Amdahl's Law, ``ex3()`` must take at least 1.5 +Due to Amdahl's Law, ``ex3()`` must take at least 1.5 seconds -- the time it takes for 5 serial calls to ``other_func()`` to finish! After an additional 0.5 seconds to execute func and get the result, the computation is done. @@ -224,7 +224,7 @@ Profiling Using An External Profiler (Line Profiler) One way to profile the performance of our code using Ray is to use a third-party profiler such as `Line_profiler`_. Line_profiler is a useful line-by-line profiler for pure Python applications that formats its output side-by-side with -the profiled code itself. +the profiled code itself. Alternatively, another third-party profiler (not covered in this documentation) that you could use is `Pyflame`_, which can generate profiling graphs. @@ -238,11 +238,11 @@ First install ``line_profiler`` with pip: pip install line_profiler -``line_profiler`` requires each section of driver code that you want to profile as -its own independent function. Conveniently, we have already done so by defining +``line_profiler`` requires each section of driver code that you want to profile as +its own independent function. Conveniently, we have already done so by defining each loop version as its own function. To tell ``line_profiler`` which functions -to profile, just add the ``@profile`` decorator to ``ex1()``, ``ex2()`` and -``ex3()``. Note that you do not need to import ``line_profiler`` into your Ray +to profile, just add the ``@profile`` decorator to ``ex1()``, ``ex2()`` and +``ex3()``. Note that you do not need to import ``line_profiler`` into your Ray application: .. code-block:: python @@ -262,16 +262,16 @@ application: if __name__ == "__main__": main() -Then, when we want to execute our Python script from the command line, instead -of ``python your_script_here.py``, we use the following shell command to run the +Then, when we want to execute our Python script from the command line, instead +of ``python your_script_here.py``, we use the following shell command to run the script with ``line_profiler`` enabled: .. code-block:: bash - kernprof -l your_script_here.py + kernprof -l your_script_here.py -This command runs your script and prints only your script's output as usual. -``Line_profiler`` instead outputs its profiling results to a corresponding +This command runs your script and prints only your script's output as usual. +``Line_profiler`` instead outputs its profiling results to a corresponding binary file called ``your_script_here.py.lprof``. To read ``line_profiler``'s results to terminal, use this shell command: @@ -300,10 +300,10 @@ Note that execution time is given in units of 1e-06 seconds: 33 5 2508805.0 501761.0 100.0 list1.append(ray.get(func.remote())) -Notice that each hit to ``list1.append(ray.get(func.remote()))`` at line 33 -takes the full 0.5 seconds waiting for ``func()`` to finish. Meanwhile, in -``ex2()`` below, each call of ``func.remote()`` at line 40 only takes 0.127 ms, -and the majority of the time (about 1 second) is spent on waiting for ``ray.get()`` +Notice that each hit to ``list1.append(ray.get(func.remote()))`` at line 33 +takes the full 0.5 seconds waiting for ``func()`` to finish. 
Meanwhile, in +``ex2()`` below, each call of ``func.remote()`` at line 40 only takes 0.127 ms, +and the majority of the time (about 1 second) is spent on waiting for ``ray.get()`` at the end: @@ -323,11 +323,11 @@ at the end: 41 1 1002919.0 1002919.0 99.9 ray.get(list2) -And finally, ``line_profiler``'s output for ``ex3()``. Each call to -``func.remote()`` at line 50 still take magnitudes faster than 0.5 seconds, -showing that Ray is successfully parallelizing the remote calls. However, each -call to the local function ``other_func()`` takes the full 0.3 seconds, -totalling up to the guaranteed minimum application execution time of 1.5 +And finally, ``line_profiler``'s output for ``ex3()``. Each call to +``func.remote()`` at line 50 still take magnitudes faster than 0.5 seconds, +showing that Ray is successfully parallelizing the remote calls. However, each +call to the local function ``other_func()`` takes the full 0.3 seconds, +totalling up to the guaranteed minimum application execution time of 1.5 seconds: .. code-block:: bash @@ -351,20 +351,20 @@ seconds: Profiling Using Python's CProfile --------------------------------- -A second way to profile the performance of your Ray application is to -use Python's native cProfile `profiling module`_. Rather than tracking +A second way to profile the performance of your Ray application is to +use Python's native cProfile `profiling module`_. Rather than tracking line-by-line of your application code, cProfile can give the total runtime of each loop function, as well as list the number of calls made and -execution time of all function calls made within the profiled code. +execution time of all function calls made within the profiled code. .. _`profiling module`: https://docs.python.org/3/library/profile.html#module-cProfile -Unlike ``line_profiler`` above, this detailed list of profiled function calls -**includes** internal function calls and function calls made within Ray! +Unlike ``line_profiler`` above, this detailed list of profiled function calls +**includes** internal function calls and function calls made within Ray! -However, similar to ``line_profiler``, cProfile can be enabled with minimal -changes to your application code (given that each section of the code you want -to profile is defined as its own function). To use cProfile, add an import +However, similar to ``line_profiler``, cProfile can be enabled with minimal +changes to your application code (given that each section of the code you want +to profile is defined as its own function). To use cProfile, add an import statement, then replace calls to the loop functions as follows: .. code-block:: python @@ -385,17 +385,17 @@ statement, then replace calls to the loop functions as follows: if __name__ == "__main__": main() -Now, when executing your Python script, a cProfile list of profiled function +Now, when executing your Python script, a cProfile list of profiled function calls will be outputted to terminal for each call made to ``cProfile.run()``. -At the very top of cProfile's output gives the total execution time for +At the very top of cProfile's output gives the total execution time for ``'ex1()'``: .. code-block:: bash 601 function calls (595 primitive calls) in 2.509 seconds -Following is a snippet of profiled function calls for ``'ex1()'``. Most of -these calls are quick and take around 0.000 seconds, so the functions of +Following is a snippet of profiled function calls for ``'ex1()'``. 
Most of +these calls are quick and take around 0.000 seconds, so the functions of interest are the ones with non-zero execution times: .. code-block:: bash @@ -405,7 +405,7 @@ interest are the ones with non-zero execution times: 1 0.000 0.000 2.509 2.509 your_script_here.py:31(ex1) 5 0.000 0.000 0.001 0.000 remote_function.py:103(remote) 5 0.000 0.000 0.001 0.000 remote_function.py:107(_submit) - ... + ... 10 0.000 0.000 0.000 0.000 worker.py:2459(__init__) 5 0.000 0.000 2.508 0.502 worker.py:2535(get) 5 0.000 0.000 0.000 0.000 worker.py:2695(get_global_worker) @@ -414,25 +414,25 @@ interest are the ones with non-zero execution times: 5 0.000 0.000 0.000 0.000 worker.py:514(submit_task) ... -The 5 separate calls to Ray's ``get``, taking the full 0.502 seconds each call, -can be noticed at ``worker.py:2535(get)``. Meanwhile, the act of calling the -remote function itself at ``remote_function.py:103(remote)`` only takes 0.001 -seconds over 5 calls, and thus is not the source of the slow performance of +The 5 separate calls to Ray's ``get``, taking the full 0.502 seconds each call, +can be noticed at ``worker.py:2535(get)``. Meanwhile, the act of calling the +remote function itself at ``remote_function.py:103(remote)`` only takes 0.001 +seconds over 5 calls, and thus is not the source of the slow performance of ``ex1()``. Profiling Ray Actors with cProfile ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Considering that the detailed output of cProfile can be quite different depending -on what Ray functionalities we use, let us see what cProfile's output might look -like if our example involved Actors (for an introduction to Ray actors, see our -`Actor documentation here`_). +Considering that the detailed output of cProfile can be quite different depending +on what Ray functionalities we use, let us see what cProfile's output might look +like if our example involved Actors (for an introduction to Ray actors, see our +`Actor documentation here`_). .. _`Actor documentation here`: http://ray.readthedocs.io/en/latest/actors.html Now, instead of looping over five calls to a remote function like in ``ex1``, -let's create a new example and loop over five calls to a remote function +let's create a new example and loop over five calls to a remote function **inside an actor**. Our actor's remote function again just sleeps for 0.5 seconds: @@ -440,7 +440,7 @@ seconds: # Our actor @ray.remote - class Sleeper(object): + class Sleeper(object): def __init__(self): self.sleepValue = 0.5 @@ -448,7 +448,7 @@ seconds: def actor_func(self): time.sleep(self.sleepValue) -Recalling the suboptimality of ``ex1``, let's first see what happens if we +Recalling the suboptimality of ``ex1``, let's first see what happens if we attempt to perform all five ``actor_func()`` calls within a single actor: .. code-block:: python @@ -470,7 +470,7 @@ We enable cProfile on this example as follows: def main(): ray.init() - cProfile.run('ex4()') + cProfile.run('ex4()') if __name__ == "__main__": main() @@ -497,22 +497,22 @@ Running our new Actor example, cProfile's abbreviated output is as follows: 8 0.000 0.000 0.001 0.000 worker.py:514(submit_task) ... -It turns out that the entire example still took 2.5 seconds to execute, or the -time for five calls to ``actor_func()`` to run in serial. We remember in ``ex1`` -that this behavior was because we did not wait until after submitting all five +It turns out that the entire example still took 2.5 seconds to execute, or the +time for five calls to ``actor_func()`` to run in serial. 
We remember in ``ex1`` +that this behavior was because we did not wait until after submitting all five remote function tasks to call ``ray.get()``, but we can verify on cProfile's -output line ``worker.py:2535(get)`` that ``ray.get()`` was only called once at -the end, for 2.509 seconds. What happened? +output line ``worker.py:2535(get)`` that ``ray.get()`` was only called once at +the end, for 2.509 seconds. What happened? -It turns out Ray cannot parallelize this example, because we have only -initialized a single ``Sleeper`` actor. Because each actor is a single, -stateful worker, our entire code is submitted and ran on a single worker the +It turns out Ray cannot parallelize this example, because we have only +initialized a single ``Sleeper`` actor. Because each actor is a single, +stateful worker, our entire code is submitted and ran on a single worker the whole time. To better parallelize the actors in ``ex4``, we can take advantage that each call to ``actor_func()`` is independent, and instead create five ``Sleeper`` actors. That way, we are creating five workers -that can run in parallel, instead of creating a single worker that +that can run in parallel, instead of creating a single worker that can only handle one call to ``actor_func()`` at a time. .. code-block:: python @@ -530,7 +530,7 @@ can only handle one call to ``actor_func()`` at a time. Our example in total now takes only 1.5 seconds to run: -.. code-block:: bash +.. code-block:: bash 1378 function calls (1363 primitive calls) in 1.567 seconds @@ -553,27 +553,27 @@ Our example in total now takes only 1.5 seconds to run: Visualizing Tasks in the Ray Timeline ------------------------------------- -Profiling the performance of your Ray application doesn't need to be -an eye-straining endeavor of interpreting numbers among hundreds of -lines of text. Ray comes with its own visual web UI to visualize the +Profiling the performance of your Ray application doesn't need to be +an eye-straining endeavor of interpreting numbers among hundreds of +lines of text. Ray comes with its own visual web UI to visualize the parallelization (or lack thereof) of user tasks submitted to Ray! -This method does have its own limitations, however. The Ray Timeline +This method does have its own limitations, however. The Ray Timeline can only show timing info about Ray tasks, and not timing for normal Python functions. This can be an issue especially for debugging slow -Python code that is running on the driver, and not running as a task on -one of the workers. The other profiling techniques above are options that +Python code that is running on the driver, and not running as a task on +one of the workers. The other profiling techniques above are options that do cover profiling normal Python functions. Currently, whenever initializing Ray, a URL is generated and printed -in the terminal. This URL can be used to view Ray's web UI as a Jupyter +in the terminal. This URL can be used to view Ray's web UI as a Jupyter notebook: .. code-block:: bash ~$: python your_script_here.py - Process STDOUT and STDERR is being redirected to /tmp/raylogs/. + Process STDOUT and STDERR is being redirected to /tmp/ray/session_2018-11-01_14-31-43_27211/logs. Waiting for redis server at 127.0.0.1:61150 to respond... Waiting for redis server at 127.0.0.1:21607 to respond... Starting local scheduler with the following resources: {'CPU': 4, 'GPU': 0}. 
@@ -582,13 +582,13 @@ notebook: View the web UI at http://localhost:8897/notebooks/ray_ui84907.ipynb?token=025e8ab295270a57fac209204b37349fdf34e037671a13ff ====================================================================== -Ray's web UI attempts to run on localhost at port 8888, and if it fails -it tries successive ports until it finds an open port. In this above +Ray's web UI attempts to run on localhost at port 8888, and if it fails +it tries successive ports until it finds an open port. In this above example, it has opened on port 8897. -Because this web UI is only available as long as your Ray application -is currently running, you may need to add a user prompt to prevent -your Ray application from exiting once it has finished executing, +Because this web UI is only available as long as your Ray application +is currently running, you may need to add a user prompt to prevent +your Ray application from exiting once it has finished executing, such as below. You can then browse the web UI for as long as you like: .. code-block:: python @@ -606,44 +606,44 @@ such as below. You can then browse the web UI for as long as you like: main() Now, when executing your python script, you can access the Ray timeline -by copying the web UI URL into your web browser on the Ray machine. To -load the web UI in the jupyter notebook, select **Kernel -> Restart and +by copying the web UI URL into your web browser on the Ray machine. To +load the web UI in the jupyter notebook, select **Kernel -> Restart and Run All** in the jupyter menu. -The Ray timeline can be viewed in the fourth cell of the UI notebook by -using the task filter options, then clicking on the **View task timeline** +The Ray timeline can be viewed in the fourth cell of the UI notebook by +using the task filter options, then clicking on the **View task timeline** button. -For example, here are the results of executing ``ex1()``, ``ex2()``, and -``ex3()`` visualized in the Ray timeline. Each red block is a call to one -of our user-defined remote functions, namely ``func()``, which sleeps for +For example, here are the results of executing ``ex1()``, ``ex2()``, and +``ex3()`` visualized in the Ray timeline. Each red block is a call to one +of our user-defined remote functions, namely ``func()``, which sleeps for 0.5 seconds: .. image:: user-profiling-timeline.gif -(highlighted color boxes for ``ex1()``, ``ex2()``, and ``ex3()`` added for +(highlighted color boxes for ``ex1()``, ``ex2()``, and ``ex3()`` added for the sake of this example) -Note how ``ex1()`` executes all five calls to ``func()`` in serial, +Note how ``ex1()`` executes all five calls to ``func()`` in serial, while ``ex2()`` and ``ex3()`` are able to parallelize their remote -function calls. +function calls. -Because we have 4 CPUs available on our machine, we can only able to -execute up to 4 remote functions in parallel. So, the fifth call to the -remote function in ``ex2()`` must wait until the first batch of ``func()`` +Because we have 4 CPUs available on our machine, we can only able to +execute up to 4 remote functions in parallel. So, the fifth call to the +remote function in ``ex2()`` must wait until the first batch of ``func()`` calls is finished. -In ``ex3()``, because of the serial dependency on ``other_func()``, we +In ``ex3()``, because of the serial dependency on ``other_func()``, we aren't even able to use all 4 of our cores to parallelize calls to ``func()``. 
The time gaps between the ``func()`` blocks are a result of staggering the -calls to ``func()`` in between waiting 0.3 seconds for ``other_func()``. +calls to ``func()`` in between waiting 0.3 seconds for ``other_func()``. -Also, notice that due to the aforementioned limitation of the Ray timeline, -``other_func()``, as a driver function and not a Ray task, is never +Also, notice that due to the aforementioned limitation of the Ray timeline, +``other_func()``, as a driver function and not a Ray task, is never visualized on the Ray timeline. **For more on Ray's Web UI,** such as how to access the UI on a remote -node over ssh, or for troubleshooting installation, please see our +node over ssh, or for troubleshooting installation, please see our `Web UI documentation section`_. .. _`Web UI documentation section`: http://ray.readthedocs.io/en/latest/webui.html diff --git a/doc/source/using-ray-and-docker-on-a-cluster.md b/doc/source/using-ray-and-docker-on-a-cluster.md index 9ae39d17851ef..4e7b7a52d9bd6 100644 --- a/doc/source/using-ray-and-docker-on-a-cluster.md +++ b/doc/source/using-ray-and-docker-on-a-cluster.md @@ -1,4 +1,4 @@ -# Using Ray and Docker on a Cluster (EXPERIMENTAL) +# Using Ray and Docker on a Cluster (Experimental) Packaging and deploying an application using Docker can provide certain advantages. It can make managing dependencies easier, help ensure that each cluster node receives a uniform configuration, and facilitate swapping hardware resources between applications. diff --git a/doc/source/using-ray-on-a-cluster.rst b/doc/source/using-ray-on-a-cluster.rst index 29c2585ac7cfe..611e47b79db23 100644 --- a/doc/source/using-ray-on-a-cluster.rst +++ b/doc/source/using-ray-on-a-cluster.rst @@ -51,7 +51,6 @@ Now we've started all of the Ray processes on each node Ray. This includes - An object store on each machine. - A local scheduler on each machine. - Multiple Redis servers (on the head node). -- One global scheduler (on the head node). To run some commands, start up Python on one of the nodes in the cluster, and do the following. diff --git a/doc/source/using-ray-on-a-large-cluster.rst b/doc/source/using-ray-on-a-large-cluster.rst index c3d6d8a8d2389..b87c8c05f5125 100644 --- a/doc/source/using-ray-on-a-large-cluster.rst +++ b/doc/source/using-ray-on-a-large-cluster.rst @@ -154,7 +154,6 @@ Now you have started all of the Ray processes on each node. These include: - An object store on each machine. - A local scheduler on each machine. - Multiple Redis servers (on the head node). -- One global scheduler (on the head node). 
To confirm that the Ray cluster setup is working, start up Python on one of the nodes in the cluster and enter the following commands to connect to the Ray diff --git a/docker/examples/Dockerfile b/docker/examples/Dockerfile index d4e6c34b22179..9cdee4ff117eb 100644 --- a/docker/examples/Dockerfile +++ b/docker/examples/Dockerfile @@ -5,6 +5,7 @@ FROM ray-project/deploy # This updates numpy to 1.14 and mutes errors from other libraries RUN conda install -y numpy RUN apt-get install -y zlib1g-dev -RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras +RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras pytest-timeout +RUN pip install -U h5py # Mutes FutureWarnings RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git RUN conda install pytorch-cpu torchvision-cpu -c pytorch diff --git a/examples/carla/a3c_lane_keep.py b/examples/carla/a3c_lane_keep.py deleted file mode 100644 index 1338736d23f5e..0000000000000 --- a/examples/carla/a3c_lane_keep.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import register_env, run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import LANE_KEEP - -env_name = "carla_env" -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "use_depth_camera": False, - "discrete_actions": False, - "server_map": "/Game/Maps/Town02", - "reward_function": "lane_keep", - "enable_planner": False, - "scenarios": [LANE_KEEP], -}) - -register_env(env_name, lambda env_config: CarlaEnv(env_config)) -register_carla_model() - -ray.init() -run_experiments({ - "carla-a3c": { - "run": "A3C", - "env": "carla_env", - "trial_resources": {"cpu": 4, "gpu": 1}, - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [80, 80, 6], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "gamma": 0.8, - "num_workers": 1, - }, - }, -}) diff --git a/examples/carla/dqn_lane_keep.py b/examples/carla/dqn_lane_keep.py deleted file mode 100644 index 2746a1c4bbd89..0000000000000 --- a/examples/carla/dqn_lane_keep.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import register_env, run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import LANE_KEEP - -env_name = "carla_env" -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "use_depth_camera": False, - "discrete_actions": True, - "server_map": "/Game/Maps/Town02", - "reward_function": "lane_keep", - "enable_planner": False, - "scenarios": [LANE_KEEP], -}) - -register_env(env_name, lambda env_config: CarlaEnv(env_config)) -register_carla_model() - -ray.init() -run_experiments({ - "carla-dqn": { - "run": "DQN", - "env": "carla_env", - "trial_resources": {"cpu": 4, "gpu": 1}, - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [80, 80, 6], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "timesteps_per_iteration": 100, - "learning_starts": 1000, - "schedule_max_timesteps": 100000, - "gamma": 0.8, - "tf_session_args": { 
- "gpu_options": {"allow_growth": True}, - }, - }, - }, -}) diff --git a/examples/carla/ppo_lane_keep.py b/examples/carla/ppo_lane_keep.py deleted file mode 100644 index 25e5acbf328c4..0000000000000 --- a/examples/carla/ppo_lane_keep.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import register_env, run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import LANE_KEEP - -env_name = "carla_env" -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "use_depth_camera": False, - "discrete_actions": False, - "server_map": "/Game/Maps/Town02", - "reward_function": "lane_keep", - "enable_planner": False, - "scenarios": [LANE_KEEP], -}) - -register_env(env_name, lambda env_config: CarlaEnv(env_config)) -register_carla_model() - -ray.init() -run_experiments({ - "carla-ppo": { - "run": "PPO", - "env": "carla_env", - "trial_resources": {"cpu": 4, "gpu": 1}, - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [80, 80, 6], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "num_workers": 1, - "timesteps_per_batch": 2000, - "min_steps_per_task": 100, - "lambda": 0.95, - "clip_param": 0.2, - "num_sgd_iter": 20, - "sgd_stepsize": 0.0001, - "sgd_batchsize": 32, - "devices": ["/gpu:0"], - "tf_session_args": { - "gpu_options": {"allow_growth": True} - } - }, - }, -}) diff --git a/examples/carla/scenarios.py b/examples/carla/scenarios.py deleted file mode 100644 index e6494af1830d0..0000000000000 --- a/examples/carla/scenarios.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Collection of Carla scenarios, including those from the CoRL 2017 paper.""" - - -TEST_WEATHERS = [0, 2, 5, 7, 9, 10, 11, 12, 13] -TRAIN_WEATHERS = [1, 3, 4, 6, 8, 14] - - -def build_scenario( - city, start, end, vehicles, pedestrians, max_steps, weathers): - return { - "city": city, - "num_vehicles": vehicles, - "num_pedestrians": pedestrians, - "weather_distribution": weathers, - "start_pos_id": start, - "end_pos_id": end, - "max_steps": max_steps, - } - - -# Simple scenario for Town02 that involves driving down a road -DEFAULT_SCENARIO = build_scenario( - city="Town02", start=36, end=40, vehicles=20, pedestrians=40, - max_steps=200, weathers=[0]) - -# Simple scenario for Town02 that involves driving down a road -LANE_KEEP = build_scenario( - city="Town02", start=36, end=40, vehicles=0, pedestrians=0, - max_steps=2000, weathers=[0]) - -# Scenarios from the CoRL2017 paper -POSES_TOWN1_STRAIGHT = [ - [36, 40], [39, 35], [110, 114], [7, 3], [0, 4], - [68, 50], [61, 59], [47, 64], [147, 90], [33, 87], - [26, 19], [80, 76], [45, 49], [55, 44], [29, 107], - [95, 104], [84, 34], [53, 67], [22, 17], [91, 148], - [20, 107], [78, 70], [95, 102], [68, 44], [45, 69]] - - -POSES_TOWN1_ONE_CURVE = [ - [138, 17], [47, 16], [26, 9], [42, 49], [140, 124], - [85, 98], [65, 133], [137, 51], [76, 66], [46, 39], - [40, 60], [0, 29], [4, 129], [121, 140], [2, 129], - [78, 44], [68, 85], [41, 102], [95, 70], [68, 129], - [84, 69], [47, 79], [110, 15], [130, 17], [0, 17]] - -POSES_TOWN1_NAV = [ - [105, 29], [27, 130], [102, 87], [132, 27], [24, 44], - [96, 26], [34, 67], [28, 1], [140, 134], [105, 9], - [148, 129], [65, 18], [21, 16], [147, 97], [42, 51], - [30, 41], [18, 107], [69, 45], [102, 95], [18, 145], - [111, 64], 
[79, 45], [84, 69], [73, 31], [37, 81]] - - -POSES_TOWN2_STRAIGHT = [ - [38, 34], [4, 2], [12, 10], [62, 55], [43, 47], - [64, 66], [78, 76], [59, 57], [61, 18], [35, 39], - [12, 8], [0, 18], [75, 68], [54, 60], [45, 49], - [46, 42], [53, 46], [80, 29], [65, 63], [0, 81], - [54, 63], [51, 42], [16, 19], [17, 26], [77, 68]] - -POSES_TOWN2_ONE_CURVE = [ - [37, 76], [8, 24], [60, 69], [38, 10], [21, 1], - [58, 71], [74, 32], [44, 0], [71, 16], [14, 24], - [34, 11], [43, 14], [75, 16], [80, 21], [3, 23], - [75, 59], [50, 47], [11, 19], [77, 34], [79, 25], - [40, 63], [58, 76], [79, 55], [16, 61], [27, 11]] - -POSES_TOWN2_NAV = [ - [19, 66], [79, 14], [19, 57], [23, 1], - [53, 76], [42, 13], [31, 71], [33, 5], - [54, 30], [10, 61], [66, 3], [27, 12], - [79, 19], [2, 29], [16, 14], [5, 57], - [70, 73], [46, 67], [57, 50], [61, 49], [21, 12], - [51, 81], [77, 68], [56, 65], [43, 54]] - -TOWN1_STRAIGHT = [ - build_scenario("Town01", start, end, 0, 0, 300, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_STRAIGHT] - -TOWN1_ONE_CURVE = [ - build_scenario("Town01", start, end, 0, 0, 600, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_ONE_CURVE] - -TOWN1_NAVIGATION = [ - build_scenario("Town01", start, end, 0, 0, 900, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_NAV] - -TOWN1_NAVIGATION_DYNAMIC = [ - build_scenario("Town01", start, end, 20, 50, 900, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_NAV] - -TOWN2_STRAIGHT = [ - build_scenario("Town02", start, end, 0, 0, 300, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_STRAIGHT] - -TOWN2_STRAIGHT_DYNAMIC = [ - build_scenario("Town02", start, end, 20, 50, 300, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_STRAIGHT] - -TOWN2_ONE_CURVE = [ - build_scenario("Town02", start, end, 0, 0, 600, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_ONE_CURVE] - -TOWN2_NAVIGATION = [ - build_scenario("Town02", start, end, 0, 0, 900, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_NAV] - -TOWN2_NAVIGATION_DYNAMIC = [ - build_scenario("Town02", start, end, 20, 50, 900, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_NAV] - -TOWN1_ALL = ( - TOWN1_STRAIGHT + TOWN1_ONE_CURVE + TOWN1_NAVIGATION + - TOWN1_NAVIGATION_DYNAMIC) - -TOWN2_ALL = ( - TOWN2_STRAIGHT + TOWN2_ONE_CURVE + TOWN2_NAVIGATION + - TOWN2_NAVIGATION_DYNAMIC) diff --git a/examples/custom_env/README b/examples/custom_env/README deleted file mode 100644 index 75ffcad88fb35..0000000000000 --- a/examples/custom_env/README +++ /dev/null @@ -1 +0,0 @@ -Example of using a custom gym env with RLlib. diff --git a/java/README.rst b/java/README.rst index 95ab961e769dc..e016169357874 100644 --- a/java/README.rst +++ b/java/README.rst @@ -7,6 +7,7 @@ Ray will read your configurations in the following order: * Java system properties: e.g., ``-Dray.home=/path/to/ray``. * A ``ray.conf`` file in the classpath: `example `_. +* Customise your own ``ray.conf`` path using system property ``-Dray.config=/path/to/ray.conf`` For all available config items and default values, see `this file `_. diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index 053f01d5534d3..7e252274ef735 100644 --- a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -41,7 +41,10 @@ public static synchronized void init(RayRuntimeFactory factory) { * Shutdown Ray runtime. 
*/ public static void shutdown() { - runtime.shutdown(); + if (runtime != null) { + runtime.shutdown(); + runtime = null; + } } /** diff --git a/java/api/src/main/java/org/ray/api/RayCall.java b/java/api/src/main/java/org/ray/api/RayCall.java index ef40a238c0e29..967830199402c 100644 --- a/java/api/src/main/java/org/ray/api/RayCall.java +++ b/java/api/src/main/java/org/ray/api/RayCall.java @@ -2,6 +2,7 @@ package org.ray.api; +import org.ray.api.function.RayFunc; import org.ray.api.function.RayFunc0; import org.ray.api.function.RayFunc1; import org.ray.api.function.RayFunc2; @@ -9,6 +10,9 @@ import org.ray.api.function.RayFunc4; import org.ray.api.function.RayFunc5; import org.ray.api.function.RayFunc6; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.BaseTaskOptions; +import org.ray.api.options.CallOptions; /** * This class provides type-safe interfaces for `Ray.call` and `Ray.createActor`. @@ -20,511 +24,1019 @@ class RayCall { // ======================================= public static RayObject call(RayFunc0 f) { Object[] args = new Object[]{}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc0 f, CallOptions options) { + Object[] args = new Object[]{}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc1 f, T0 t0) { Object[] args = new Object[]{t0}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc1 f, RayObject t0) { Object[] args = new Object[]{t0}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc1 f, T0 t0, CallOptions options) { + Object[] args = new Object[]{t0}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc1 f, RayObject t0, CallOptions options) { + Object[] args = new Object[]{t0}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc2 f, T0 t0, T1 t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc2 f, T0 t0, RayObject t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc2 f, RayObject t0, T1 t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc2 f, RayObject t0, RayObject t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc2 f, T0 t0, T1 t1, CallOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc2 f, T0 t0, RayObject t1, CallOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc2 f, RayObject t0, T1 t1, CallOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc2 f, RayObject t0, RayObject t1, CallOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc3 f, T0 t0, T1 t1, T2 t2) { Object[] args = 
new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, T0 t0, T1 t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, T0 t0, RayObject t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, T0 t0, RayObject t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, RayObject t0, T1 t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, RayObject t0, T1 t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, RayObject t0, RayObject t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, RayObject t0, RayObject t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc3 f, T0 t0, T1 t1, T2 t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, T0 t0, T1 t1, RayObject t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, T0 t0, RayObject t1, T2 t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, T0 t0, RayObject t1, RayObject t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, RayObject t0, T1 t1, T2 t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, RayObject t0, T1 t1, RayObject t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, RayObject t0, RayObject t1, T2 t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, RayObject t0, RayObject t1, RayObject t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc4 f, T0 t0, T1 t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, T1 t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, T1 
t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, T1 t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc4 f, T0 t0, T1 t1, T2 t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, T1 t1, T2 t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, T1 t1, RayObject t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; 
+ return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, T1 t1, RayObject t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, T2 t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, T2 t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, T2 t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, T2 t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new 
Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); 
} public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, 
T1 t1, T2 t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, 
T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } 
public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + 
return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, 
RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return 
Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + 
return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return 
Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, 
options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, 
options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new 
Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); } // =========================================== // Methods for remote actor method invocation. 
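Note (not part of the patch): the hunk above only adds mechanically generated overloads. For every existing Ray.call signature (zero to six arguments, each position either a plain value or a RayObject), a twin is added whose last parameter is a CallOptions object forwarded to Ray.internal().call(f, args, options). A minimal usage sketch follows; the import paths, the CallOptions constructor taking a resource map, and the annotation remark are assumptions for illustration — only the Ray.call(f, arg, options) shape comes from the generated code above.

    import java.util.HashMap;
    import java.util.Map;
    // Package names below are assumed for illustration.
    import org.ray.api.Ray;
    import org.ray.api.RayObject;
    import org.ray.api.options.CallOptions;

    public class CallOptionsExample {
      // Remote function; depending on the Ray version this may require a
      // @RayRemote-style annotation, omitted here.
      public static int square(int x) {
        return x * x;
      }

      public static void main(String[] args) {
        Ray.init();
        // Hypothetical construction: a custom resource demand for this task.
        Map<String, Double> resources = new HashMap<>();
        resources.put("CPU", 1.0);
        CallOptions options = new CallOptions(resources); // assumed constructor
        // One of the overloads added in this hunk: RayFunc1 plus CallOptions.
        RayObject<Integer> result = Ray.call(CallOptionsExample::square, 4, options);
        System.out.println(result.get()); // prints 16
      }
    }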
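The createActor hunk that begins below follows the same pattern: every existing createActor overload now forwards a null options argument, and a parallel overload accepting ActorCreationOptions is added for each argument combination. Again a sketch, not part of the patch; the ActorCreationOptions constructor and the exact actor-method call form are assumptions with hypothetical names, based only on the surrounding generated signatures.

    import java.util.HashMap;
    import java.util.Map;
    // Assumed import paths, as above.
    import org.ray.api.Ray;
    import org.ray.api.RayActor;
    import org.ray.api.RayObject;
    import org.ray.api.options.ActorCreationOptions;

    public class Counter {
      private int value = 0;

      public int increment() {
        return ++value;
      }

      // Static factory used as the RayFunc0 actor-creation function.
      public static Counter create() {
        return new Counter();
      }

      public static void main(String[] args) {
        Ray.init();
        // Hypothetical: options carrying a resource demand for the actor.
        Map<String, Double> resources = new HashMap<>();
        resources.put("CPU", 1.0);
        ActorCreationOptions options = new ActorCreationOptions(resources); // assumed constructor
        RayActor<Counter> counter = Ray.createActor(Counter::create, options);
        // Remote actor method invocation (the overloads referenced by the
        // comment above); the exact call form may differ by version.
        RayObject<Integer> one = Ray.call(Counter::increment, counter);
        System.out.println(one.get()); // prints 1
      }
    }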
@@ -786,510 +1298,1018 @@ public static RayObject call(RayFunc6 RayActor createActor(RayFunc0 f) { Object[] args = new Object[]{}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc0 f, ActorCreationOptions options) { + Object[] args = new Object[]{}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc1 f, T0 t0) { Object[] args = new Object[]{t0}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc1 f, RayObject t0) { Object[] args = new Object[]{t0}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc1 f, T0 t0, ActorCreationOptions options) { + Object[] args = new Object[]{t0}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc1 f, RayObject t0, ActorCreationOptions options) { + Object[] args = new Object[]{t0}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc2 f, T0 t0, T1 t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc2 f, T0 t0, RayObject t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc2 f, RayObject t0, T1 t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc2 f, RayObject t0, RayObject t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc2 f, T0 t0, T1 t1, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc2 f, T0 t0, RayObject t1, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc2 f, RayObject t0, T1 t1, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc2 f, RayObject t0, RayObject t1, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc3 f, T0 t0, T1 t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, T0 t0, T1 t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, T0 t0, RayObject t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, T0 t0, RayObject t1, RayObject t2) { Object[] args = new 
Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, RayObject t0, T1 t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, RayObject t0, T1 t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, RayObject t0, RayObject t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, RayObject t0, RayObject t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc3 f, T0 t0, T1 t1, T2 t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, T0 t0, T1 t1, RayObject t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, T0 t0, RayObject t1, T2 t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, T0 t0, RayObject t1, RayObject t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, RayObject t0, T1 t1, T2 t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, RayObject t0, T1 t1, RayObject t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, RayObject t0, RayObject t1, T2 t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, RayObject t0, RayObject t1, RayObject t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, RayObject t2, RayObject t3) { Object[] args = new 
Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, T2 t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, T2 t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, RayObject t2, T3 t3, ActorCreationOptions options) { 
+ Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, RayObject t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, T2 t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, T2 t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, T2 t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, T2 t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return 
Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static 
RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static 
RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, 
t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor 
createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public 
static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, 
t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return 
Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor 
createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor 
createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public 
static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return 
Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, 
ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public 
static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); } } diff --git a/java/api/src/main/java/org/ray/api/annotation/RayRemote.java b/java/api/src/main/java/org/ray/api/annotation/RayRemote.java index a47e0768f0fbf..197ee663f58a0 100644 --- a/java/api/src/main/java/org/ray/api/annotation/RayRemote.java +++ b/java/api/src/main/java/org/ray/api/annotation/RayRemote.java @@ -15,10 +15,4 @@ @Target({ElementType.METHOD, ElementType.TYPE}) public @interface RayRemote { - /** - * Defines the quantity of various custom resources to reserve - * for this task or for the lifetime of the actor. - * @return an array of custom resource items. 
- */ - ResourceItem[] resources() default {}; } diff --git a/java/api/src/main/java/org/ray/api/annotation/ResourceItem.java b/java/api/src/main/java/org/ray/api/annotation/ResourceItem.java deleted file mode 100644 index f4895eba6164c..0000000000000 --- a/java/api/src/main/java/org/ray/api/annotation/ResourceItem.java +++ /dev/null @@ -1,28 +0,0 @@ -package org.ray.api.annotation; - - -import java.lang.annotation.Documented; -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; - -/** - * Represents a custom resource, including its name and quantity. - */ -@Documented -@Retention(RetentionPolicy.RUNTIME) -@Target(ElementType.ANNOTATION_TYPE) -public @interface ResourceItem { - - /** - * Name of this resource, must not be null or empty. - */ - String name(); - - /** - * Quantity of this resource. - */ - double value() default 0; - -} diff --git a/java/api/src/main/java/org/ray/api/id/UniqueId.java b/java/api/src/main/java/org/ray/api/id/UniqueId.java index 0d32d0f8f3c4b..f93bdc737229e 100644 --- a/java/api/src/main/java/org/ray/api/id/UniqueId.java +++ b/java/api/src/main/java/org/ray/api/id/UniqueId.java @@ -112,6 +112,6 @@ public boolean equals(Object obj) { @Override public String toString() { - return DatatypeConverter.printHexBinary(id); + return DatatypeConverter.printHexBinary(id).toLowerCase(); } } diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java new file mode 100644 index 0000000000000..20db30944e513 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -0,0 +1,18 @@ +package org.ray.api.options; + +import java.util.Map; + +/** + * The options for creating actor. + */ +public class ActorCreationOptions extends BaseTaskOptions { + + public ActorCreationOptions() { + super(); + } + + public ActorCreationOptions(Map resources) { + super(resources); + } + +} diff --git a/java/api/src/main/java/org/ray/api/options/BaseTaskOptions.java b/java/api/src/main/java/org/ray/api/options/BaseTaskOptions.java new file mode 100644 index 0000000000000..65494d532a687 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/options/BaseTaskOptions.java @@ -0,0 +1,20 @@ +package org.ray.api.options; + +import java.util.HashMap; +import java.util.Map; + +/** + * The options class for RayCall or ActorCreation. + */ +public abstract class BaseTaskOptions { + public Map resources; + + public BaseTaskOptions() { + resources = new HashMap<>(); + } + + public BaseTaskOptions(Map resources) { + this.resources = resources; + } + +} diff --git a/java/api/src/main/java/org/ray/api/options/CallOptions.java b/java/api/src/main/java/org/ray/api/options/CallOptions.java new file mode 100644 index 0000000000000..84adfc122e04a --- /dev/null +++ b/java/api/src/main/java/org/ray/api/options/CallOptions.java @@ -0,0 +1,18 @@ +package org.ray.api.options; + +import java.util.Map; + +/** + * The options for RayCall. 
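Note: with the ResourceItem annotation removed above, per-task and per-actor resource requirements are now supplied at invocation time through the new options classes. A minimal caller-side sketch, assuming the generated Ray.call/Ray.createActor overloads introduced in this patch; the Counter class and doubleIt function are hypothetical and exist only for illustration:

    import java.util.HashMap;
    import java.util.Map;
    import org.ray.api.Ray;
    import org.ray.api.RayActor;
    import org.ray.api.RayObject;
    import org.ray.api.annotation.RayRemote;
    import org.ray.api.options.ActorCreationOptions;
    import org.ray.api.options.CallOptions;

    public class OptionsUsageSketch {

      // Hypothetical actor class, used only to illustrate the new API.
      @RayRemote
      public static class Counter {
        private int value = 0;

        public int increment() {
          return ++value;
        }
      }

      @RayRemote
      public static int doubleIt(int x) {
        return x * 2;
      }

      public static void main(String[] args) {
        Ray.init();

        // Reserve 2 CPUs for the lifetime of the actor. This replaces the removed
        // @RayRemote(resources = {@ResourceItem(...)}) declaration.
        Map<String, Double> actorResources = new HashMap<>();
        actorResources.put("CPU", 2.0);
        RayActor<Counter> counter =
            Ray.createActor(Counter::new, new ActorCreationOptions(actorResources));

        // Actor method calls are unchanged; only plain remote calls and actor
        // creation gained options overloads in this patch.
        RayObject<Integer> incremented = Ray.call(Counter::increment, counter);

        // Reserve 1 GPU for a single remote call via CallOptions.
        Map<String, Double> callResources = new HashMap<>();
        callResources.put("GPU", 1.0);
        RayObject<Integer> doubled =
            Ray.call(OptionsUsageSketch::doubleIt, 1, new CallOptions(callResources));
      }
    }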
+ */ +public class CallOptions extends BaseTaskOptions { + + public CallOptions() { + super(); + } + + public CallOptions(Map resources) { + super(resources); + } + +} diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index d609d4de593d1..7c12c3543c04a 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -6,6 +6,9 @@ import org.ray.api.WaitResult; import org.ray.api.function.RayFunc; import org.ray.api.id.UniqueId; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.BaseTaskOptions; +import org.ray.api.options.CallOptions; /** * Base interface of a Ray runtime. @@ -65,9 +68,10 @@ public interface RayRuntime { * * @param func The remote function to run. * @param args The arguments of the remote function. + * @param options The options for this call. * @return The result object. */ - RayObject call(RayFunc func, Object[] args); + RayObject call(RayFunc func, Object[] args, CallOptions options); /** * Invoke a remote function on an actor. @@ -85,7 +89,9 @@ public interface RayRuntime { * @param actorFactoryFunc A remote function whose return value is the actor object. * @param args The arguments for the remote function. * @param The type of the actor object. + * @param options The options for creating actor. * @return A handle to the actor. */ - RayActor createActor(RayFunc actorFactoryFunc, Object[] args); + RayActor createActor(RayFunc actorFactoryFunc, Object[] args, + ActorCreationOptions options); } diff --git a/java/checkstyle-suppressions.xml b/java/checkstyle-suppressions.xml index 619c24e1466f0..0422332258dfa 100644 --- a/java/checkstyle-suppressions.xml +++ b/java/checkstyle-suppressions.xml @@ -10,5 +10,5 @@ - + diff --git a/java/doc/installation.rst b/java/doc/installation.rst index fca3b12e7c971..8daec29ace403 100644 --- a/java/doc/installation.rst +++ b/java/doc/installation.rst @@ -26,7 +26,7 @@ For Ubuntu users, run the following commands: # If you are on Ubuntu 14.04, you need the following. 
pip install cmake - pip install cython + pip install cython==0.27.3 For macOS users, run the following commands: :: @@ -34,7 +34,7 @@ For macOS users, run the following commands: brew update brew install maven cmake pkg-config automake autoconf libtool openssl bison wget - pip install cython + pip install cython==0.27.3 Build Ray ^^^^^^^^^ diff --git a/java/prepare.sh b/java/prepare.sh index 807301a74edba..9554e500a8edd 100755 --- a/java/prepare.sh +++ b/java/prepare.sh @@ -42,15 +42,15 @@ fi # echo "ray_dir = $ray_dir" declare -a nativeBinaries=( - "./src/common/thirdparty/redis/src/redis-server" + "./src/ray/thirdparty/redis/src/redis-server" "./src/plasma/plasma_store_server" "./src/ray/raylet/raylet" "./src/ray/raylet/raylet_monitor" ) declare -a nativeLibraries=( - "./src/common/redis_module/libray_redis_module.so" - "./src/local_scheduler/liblocal_scheduler_library_java.*" + "./src/ray/gcs/redis_module/libray_redis_module.so" + "./src/ray/raylet/liblocal_scheduler_library_java.*" "./src/plasma/libplasma_java.*" "./src/ray/raylet/*lib.a" ) diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index b035f3b52bc0c..10dc172fd4d99 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -10,8 +10,12 @@ import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.exception.RayException; import org.ray.api.function.RayFunc; import org.ray.api.id.UniqueId; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.BaseTaskOptions; +import org.ray.api.options.CallOptions; import org.ray.api.runtime.RayRuntime; import org.ray.runtime.config.RayConfig; import org.ray.runtime.functionmanager.FunctionManager; @@ -22,8 +26,7 @@ import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.ResourceUtil; -import org.ray.runtime.util.UniqueIdHelper; -import org.ray.runtime.util.exception.TaskExecutionException; +import org.ray.runtime.util.UniqueIdUtil; import org.ray.runtime.util.logger.RayLog; /** @@ -48,7 +51,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { public AbstractRayRuntime(RayConfig rayConfig) { this.rayConfig = rayConfig; - functionManager = new FunctionManager(); + functionManager = new FunctionManager(rayConfig.driverResourcePath); worker = new Worker(this); workerContext = new WorkerContext(rayConfig.workerMode, rayConfig.driverId); } @@ -63,7 +66,7 @@ public AbstractRayRuntime(RayConfig rayConfig) { @Override public RayObject put(T obj) { - UniqueId objectId = UniqueIdHelper.computePutId( + UniqueId objectId = UniqueIdUtil.computePutId( workerContext.getCurrentTask().taskId, workerContext.nextPutIndex()); put(objectId, obj); @@ -72,12 +75,12 @@ public RayObject put(T obj) { public void put(UniqueId objectId, T obj) { UniqueId taskId = workerContext.getCurrentTask().taskId; - RayLog.core.info("Putting object {}, for task {} ", objectId, taskId); + RayLog.core.debug("Putting object {}, for task {} ", objectId, taskId); objectStoreProxy.put(objectId, obj, null); } @Override - public T get(UniqueId objectId) throws TaskExecutionException { + public T get(UniqueId objectId) throws RayException { List ret = get(ImmutableList.of(objectId)); return ret.get(0); } @@ -85,6 +88,8 @@ public T get(UniqueId objectId) throws TaskExecutionException { @Override public List 
get(List objectIds) { boolean wasBlocked = false; + // TODO(swang): If we are not on the main thread, then we should generate a + // random task ID to pass to the backend. UniqueId taskId = workerContext.getCurrentTask().taskId; try { @@ -94,7 +99,7 @@ public List get(List objectIds) { List> fetchBatches = splitIntoBatches(objectIds, FETCH_BATCH_SIZE); for (List batch : fetchBatches) { - rayletClient.reconstructObjects(batch, true); + rayletClient.fetchOrReconstruct(batch, true, taskId); } // Get the objects. We initially try to get the objects immediately. @@ -119,7 +124,7 @@ public List get(List objectIds) { splitIntoBatches(unreadyList, FETCH_BATCH_SIZE); for (List batch : reconstructBatches) { - rayletClient.reconstructObjects(batch, false); + rayletClient.fetchOrReconstruct(batch, false, taskId); } List> results = objectStoreProxy @@ -146,7 +151,7 @@ public List get(List objectIds) { } return finalRet; - } catch (TaskExecutionException e) { + } catch (RayException e) { RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get with Exception", e); throw e; @@ -154,7 +159,7 @@ public List get(List objectIds) { // If there were objects that we weren't able to get locally, let the local // scheduler know that we're now unblocked. if (wasBlocked) { - rayletClient.notifyUnblocked(); + rayletClient.notifyUnblocked(taskId); } } } @@ -182,12 +187,15 @@ private List> splitIntoBatches(List objectIds, int batc @Override public WaitResult wait(List> waitList, int numReturns, int timeoutMs) { - return rayletClient.wait(waitList, numReturns, timeoutMs); + // TODO(swang): If we are not on the main thread, then we should generate a + // random task ID to pass to the backend. + return rayletClient.wait(waitList, numReturns, timeoutMs, + workerContext.getCurrentTask().taskId); } @Override - public RayObject call(RayFunc func, Object[] args) { - TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, false); + public RayObject call(RayFunc func, Object[] args, CallOptions options) { + TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, false, options); rayletClient.submitTask(spec); return new RayObjectImpl(spec.returnIds[0]); } @@ -198,7 +206,7 @@ public RayObject call(RayFunc func, RayActor actor, Object[] args) { throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName()); } RayActorImpl actorImpl = (RayActorImpl)actor; - TaskSpec spec = createTaskSpec(func, actorImpl, args, false); + TaskSpec spec = createTaskSpec(func, actorImpl, args, false, null); spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor()); actorImpl.setTaskCursor(spec.returnIds[1]); rayletClient.submitTask(spec); @@ -207,8 +215,10 @@ public RayObject call(RayFunc func, RayActor actor, Object[] args) { @Override @SuppressWarnings("unchecked") - public RayActor createActor(RayFunc actorFactoryFunc, Object[] args) { - TaskSpec spec = createTaskSpec(actorFactoryFunc, RayActorImpl.NIL, args, true); + public RayActor createActor(RayFunc actorFactoryFunc, + Object[] args, ActorCreationOptions options) { + TaskSpec spec = createTaskSpec(actorFactoryFunc, RayActorImpl.NIL, + args, true, options); RayActorImpl actor = new RayActorImpl(spec.returnIds[0]); actor.increaseTaskCounter(); actor.setTaskCursor(spec.returnIds[0]); @@ -222,7 +232,7 @@ public RayActor createActor(RayFunc actorFactoryFunc, Object[] args) { private UniqueId[] genReturnIds(UniqueId taskId, int numReturns) { UniqueId[] ret = new UniqueId[numReturns]; for (int i = 0; i < 
numReturns; i++) { - ret[i] = UniqueIdHelper.computeReturnId(taskId, i + 1); + ret[i] = UniqueIdUtil.computeReturnId(taskId, i + 1); } return ret; } @@ -236,11 +246,10 @@ private UniqueId[] genReturnIds(UniqueId taskId, int numReturns) { * @return A TaskSpec object. */ private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args, - boolean isActorCreationTask) { + boolean isActorCreationTask, BaseTaskOptions taskOptions) { final TaskSpec current = workerContext.getCurrentTask(); UniqueId taskId = rayletClient.generateTaskId(current.driverId, - current.taskId, - workerContext.nextCallIndex()); + current.taskId, workerContext.nextCallIndex()); int numReturns = actor.getId().isNil() ? 1 : 2; UniqueId[] returnIds = genReturnIds(taskId, numReturns); @@ -249,6 +258,18 @@ private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args, actorCreationId = returnIds[0]; } + Map resources; + if (null == taskOptions) { + resources = new HashMap<>(); + } else { + resources = new HashMap<>(taskOptions.resources); + } + + if (!resources.containsKey(ResourceUtil.CPU_LITERAL) + && !resources.containsKey(ResourceUtil.CPU_LITERAL.toLowerCase())) { + resources.put(ResourceUtil.CPU_LITERAL, 0.0); + } + RayFunction rayFunction = functionManager.getFunction(current.driverId, func); return new TaskSpec( current.driverId, @@ -261,7 +282,7 @@ private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args, actor.increaseTaskCounter(), ArgumentsBuilder.wrap(args), returnIds, - ResourceUtil.getResourcesMapFromArray(rayFunction.getRayRemoteAnnotation()), + resources, rayFunction.getFunctionDescriptor() ); } @@ -286,4 +307,3 @@ public FunctionManager getFunctionManager() { return functionManager; } } - diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index a2ef237e28068..d4d90f24ece27 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -6,11 +6,14 @@ import com.typesafe.config.Config; import com.typesafe.config.ConfigException; import com.typesafe.config.ConfigFactory; + +import java.io.File; import java.util.List; import java.util.Map; import org.ray.api.id.UniqueId; import org.ray.runtime.util.NetworkUtil; import org.ray.runtime.util.ResourceUtil; +import org.ray.runtime.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +38,7 @@ public class RayConfig { public final boolean redirectOutput; public final List libraryPath; public final List classpath; + public final List jvmParameters; private String redisAddress; private String redisIp; @@ -51,6 +55,8 @@ public class RayConfig { public final String redisModulePath; public final String plasmaStoreExecutablePath; public final String rayletExecutablePath; + public final String driverResourcePath; + public final String pythonWorkerCommand; private void validate() { if (workerMode == WorkerMode.WORKER) { @@ -126,6 +132,18 @@ public RayConfig(Config config) { List customLibraryPath = config.getStringList("ray.library.path"); // custom classpath classpath = config.getStringList("ray.classpath"); + // custom worker jvm parameters + if (config.hasPath("ray.worker.jvm-parameters")) { + jvmParameters = config.getStringList("ray.worker.jvm-parameters"); + } else { + jvmParameters = ImmutableList.of(); + } + + if (config.hasPath("ray.worker.python-command")) { + pythonWorkerCommand = 
config.getString("ray.worker.python-command"); + } else { + pythonWorkerCommand = null; + } // redis configurations String redisAddress = config.getString("ray.redis.address"); @@ -147,15 +165,22 @@ public RayConfig(Config config) { // library path this.libraryPath = new ImmutableList.Builder().add( rayHome + "/build/src/plasma", - rayHome + "/build/src/local_scheduler" + rayHome + "/build/src/ray/raylet" ).addAll(customLibraryPath).build(); redisServerExecutablePath = rayHome + - "/build/src/common/thirdparty/redis/src/redis-server"; - redisModulePath = rayHome + "/build/src/common/redis_module/libray_redis_module.so"; + "/build/src/ray/thirdparty/redis/src/redis-server"; + redisModulePath = rayHome + "/build/src/ray/gcs/redis_module/libray_redis_module.so"; plasmaStoreExecutablePath = rayHome + "/build/src/plasma/plasma_store_server"; rayletExecutablePath = rayHome + "/build/src/ray/raylet/raylet"; + // driver resource path + if (config.hasPath("ray.driver.resource-path")) { + driverResourcePath = config.getString("ray.driver.resource-path"); + } else { + driverResourcePath = null; + } + // validate config validate(); LOGGER.debug("Created config: {}", this); @@ -219,9 +244,16 @@ public String toString() { */ public static RayConfig create() { ConfigFactory.invalidateCaches(); - Config config = ConfigFactory.systemProperties() - .withFallback(ConfigFactory.load(CUSTOM_CONFIG_FILE)) - .withFallback(ConfigFactory.load(DEFAULT_CONFIG_FILE)); + Config config = ConfigFactory.systemProperties(); + String configPath = System.getProperty("ray.config"); + if (StringUtil.isNullOrEmpty(configPath)) { + LOGGER.info("Loading config from \"ray.conf\" file in classpath."); + config = config.withFallback(ConfigFactory.load(CUSTOM_CONFIG_FILE)); + } else { + LOGGER.info("Loading config from " + configPath + "."); + config = config.withFallback(ConfigFactory.parseFile(new File(configPath))); + } + config = config.withFallback(ConfigFactory.load(DEFAULT_CONFIG_FILE)); return new RayConfig(config); } diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java index e586741641ae0..d7698c22aa7fb 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java @@ -15,20 +15,25 @@ import org.objectweb.asm.Type; import org.ray.api.function.RayFunc; import org.ray.api.id.UniqueId; +import org.ray.runtime.util.JarLoader; import org.ray.runtime.util.LambdaUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Manages functions by driver id. */ public class FunctionManager { + private static final Logger LOGGER = LoggerFactory.getLogger(FunctionManager.class); + static final String CONSTRUCTOR_NAME = ""; /** * Cache from a RayFunc object to its corresponding FunctionDescriptor. Because * `LambdaUtils.getSerializedLambda` is expensive. */ - private static final ThreadLocal, FunctionDescriptor>> + private static final ThreadLocal, FunctionDescriptor>> RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new); /** @@ -36,6 +41,21 @@ public class FunctionManager { */ private Map driverFunctionTables = new HashMap<>(); + /** + * The resource path which we can load the driver's jar resources. + */ + private String driverResourcePath; + + /** + * Construct a FunctionManager with the specified driver resource path. 
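Note: combined with the RayConfig changes above, a driver can now load configuration from an explicit file by starting the JVM with -Dray.config=/path/to/ray.conf instead of relying on a ray.conf on the classpath. When ray.driver.resource-path is also set, workers resolve a driver's jars from a subdirectory of that path named after the driver id, using the per-driver class loader built by this FunctionManager (the paths here are placeholders, not part of the patch).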
+ * + * @param driverResourcePath The specified driver resource that + * can store the driver's resources. + */ + public FunctionManager(String driverResourcePath) { + this.driverResourcePath = driverResourcePath; + } + /** * Get the RayFunction from a RayFunc instance (a lambda). * @@ -51,6 +71,7 @@ public RayFunction getFunction(UniqueId driverId, RayFunc func) { final String methodName = serializedLambda.getImplMethodName(); final String typeDescriptor = serializedLambda.getImplMethodSignature(); functionDescriptor = new FunctionDescriptor(className, methodName, typeDescriptor); + RAY_FUNC_CACHE.get().put(func.getClass(),functionDescriptor); } return getFunction(driverId, functionDescriptor); } @@ -65,8 +86,17 @@ public RayFunction getFunction(UniqueId driverId, RayFunc func) { public RayFunction getFunction(UniqueId driverId, FunctionDescriptor functionDescriptor) { DriverFunctionTable driverFunctionTable = driverFunctionTables.get(driverId); if (driverFunctionTable == null) { - //TODO(hchen): distinguish class loader by driver id. - ClassLoader classLoader = getClass().getClassLoader(); + String resourcePath = driverResourcePath + "/" + driverId.toString() + "/"; + ClassLoader classLoader; + + if (driverResourcePath != null && !driverResourcePath.isEmpty()) { + classLoader = JarLoader.loadJars(resourcePath, false); + LOGGER.info("Succeeded to load driver({}) resource. Resource path is {}", + driverId, resourcePath); + } else { + classLoader = getClass().getClassLoader(); + } + driverFunctionTable = new DriverFunctionTable(classLoader); driverFunctionTables.put(driverId, driverFunctionTable); } diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java index 3d0704c6bf484..2f39ec3dc8db4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java @@ -58,12 +58,17 @@ public FunctionDescriptor getFunctionDescriptor() { } public RayRemote getRayRemoteAnnotation() { - RayRemote rayRemote = executable.getAnnotation(RayRemote.class); - if (rayRemote == null) { - // If the method doesn't have a annotation, get the annotation from - // its wrapping class. + RayRemote rayRemote; + + // If this method is a constructor, the task of it should be a actorCreationTask. + // And the annotation of actorCreationTask should inherit from class. + // Otherwise, it's a normal method, and it shouldn't inherit annotation from class. 
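In other words, after this change the class-level @RayRemote annotation is consulted only for actor-creation (constructor) tasks; a regular method no longer falls back to it. A short illustration (the Echo class below is hypothetical):

    import org.ray.api.annotation.RayRemote;

    // Class-level annotation: applies when the constructor runs as an
    // actor-creation task.
    @RayRemote
    public class Echo {

      // A method now needs its own annotation; without it,
      // getRayRemoteAnnotation() returns null for this method instead of
      // inheriting the class-level annotation.
      @RayRemote
      public String echo(String message) {
        return message;
      }
    }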
+ if (isConstructor()) { rayRemote = executable.getDeclaringClass().getAnnotation(RayRemote.class); + } else { + rayRemote = executable.getAnnotation(RayRemote.class); } + return rayRemote; } diff --git a/java/runtime/src/main/java/org/ray/runtime/generated/TaskLanguage.java b/java/runtime/src/main/java/org/ray/runtime/generated/Language.java similarity index 51% rename from java/runtime/src/main/java/org/ray/runtime/generated/TaskLanguage.java rename to java/runtime/src/main/java/org/ray/runtime/generated/Language.java index e5e53614aa8a7..34604374dd441 100644 --- a/java/runtime/src/main/java/org/ray/runtime/generated/TaskLanguage.java +++ b/java/runtime/src/main/java/org/ray/runtime/generated/Language.java @@ -2,13 +2,13 @@ package org.ray.runtime.generated; -public final class TaskLanguage { - private TaskLanguage() { } +public final class Language { + private Language() { } public static final int PYTHON = 0; - public static final int JAVA = 1; + public static final int CPP = 1; + public static final int JAVA = 2; - public static final String[] names = { "PYTHON", "JAVA", }; + public static final String[] names = { "PYTHON", "CPP", "JAVA", }; public static String name(int e) { return names[e]; } } - diff --git a/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java b/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java index 8c0512afbc4fc..01113096036fc 100644 --- a/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java +++ b/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java @@ -48,9 +48,12 @@ public final class TaskInfo extends Table { public ResourcePair requiredResources(int j) { return requiredResources(new ResourcePair(), j); } public ResourcePair requiredResources(ResourcePair obj, int j) { int o = __offset(30); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } public int requiredResourcesLength() { int o = __offset(30); return o != 0 ? __vector_len(o) : 0; } - public int language() { int o = __offset(32); return o != 0 ? bb.getInt(o + bb_pos) : 0; } - public String functionDescriptor(int j) { int o = __offset(34); return o != 0 ? __string(__vector(o) + j * 4) : null; } - public int functionDescriptorLength() { int o = __offset(34); return o != 0 ? __vector_len(o) : 0; } + public ResourcePair requiredPlacementResources(int j) { return requiredPlacementResources(new ResourcePair(), j); } + public ResourcePair requiredPlacementResources(ResourcePair obj, int j) { int o = __offset(32); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } + public int requiredPlacementResourcesLength() { int o = __offset(32); return o != 0 ? __vector_len(o) : 0; } + public int language() { int o = __offset(34); return o != 0 ? bb.getInt(o + bb_pos) : 0; } + public String functionDescriptor(int j) { int o = __offset(36); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int functionDescriptorLength() { int o = __offset(36); return o != 0 ? 
__vector_len(o) : 0; } public static int createTaskInfo(FlatBufferBuilder builder, int driver_idOffset, @@ -67,11 +70,13 @@ public static int createTaskInfo(FlatBufferBuilder builder, int argsOffset, int returnsOffset, int required_resourcesOffset, + int required_placement_resourcesOffset, int language, int function_descriptorOffset) { - builder.startObject(16); + builder.startObject(17); TaskInfo.addFunctionDescriptor(builder, function_descriptorOffset); TaskInfo.addLanguage(builder, language); + TaskInfo.addRequiredPlacementResources(builder, required_placement_resourcesOffset); TaskInfo.addRequiredResources(builder, required_resourcesOffset); TaskInfo.addReturns(builder, returnsOffset); TaskInfo.addArgs(builder, argsOffset); @@ -89,7 +94,7 @@ public static int createTaskInfo(FlatBufferBuilder builder, return TaskInfo.endTaskInfo(builder); } - public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(16); } + public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(17); } public static void addDriverId(FlatBufferBuilder builder, int driverIdOffset) { builder.addOffset(0, driverIdOffset, 0); } public static void addTaskId(FlatBufferBuilder builder, int taskIdOffset) { builder.addOffset(1, taskIdOffset, 0); } public static void addParentTaskId(FlatBufferBuilder builder, int parentTaskIdOffset) { builder.addOffset(2, parentTaskIdOffset, 0); } @@ -110,8 +115,11 @@ public static int createTaskInfo(FlatBufferBuilder builder, public static void addRequiredResources(FlatBufferBuilder builder, int requiredResourcesOffset) { builder.addOffset(13, requiredResourcesOffset, 0); } public static int createRequiredResourcesVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startRequiredResourcesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } - public static void addLanguage(FlatBufferBuilder builder, int language) { builder.addInt(14, language, 0); } - public static void addFunctionDescriptor(FlatBufferBuilder builder, int functionDescriptorOffset) { builder.addOffset(15, functionDescriptorOffset, 0); } + public static void addRequiredPlacementResources(FlatBufferBuilder builder, int requiredPlacementResourcesOffset) { builder.addOffset(14, requiredPlacementResourcesOffset, 0); } + public static int createRequiredPlacementResourcesVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startRequiredPlacementResourcesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addLanguage(FlatBufferBuilder builder, int language) { builder.addInt(15, language, 0); } + public static void addFunctionDescriptor(FlatBufferBuilder builder, int functionDescriptorOffset) { builder.addOffset(16, functionDescriptorOffset, 0); } public static int createFunctionDescriptorVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startFunctionDescriptorVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endTaskInfo(FlatBufferBuilder builder) { @@ -136,4 +144,3 @@ public 
ByteBuffer returnsAsByteBuffer(int j) { return src; } } - diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index b497f5c44b148..5f8221ff6f028 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -4,10 +4,11 @@ import java.util.List; import org.apache.arrow.plasma.ObjectStoreLink; import org.apache.commons.lang3.tuple.Pair; +import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.util.Serializer; -import org.ray.runtime.util.exception.TaskExecutionException; +import org.ray.runtime.util.UniqueIdUtil; /** * Object store proxy, which handles serialization and deserialization, and utilize a {@code @@ -15,9 +16,10 @@ */ public class ObjectStoreProxy { + private static final int GET_TIMEOUT_MS = 1000; + private final AbstractRayRuntime runtime; private final ObjectStoreLink store; - private final int getTimeoutMs = 1000; public ObjectStoreProxy(AbstractRayRuntime runtime, ObjectStoreLink store) { this.runtime = runtime; @@ -25,18 +27,18 @@ public ObjectStoreProxy(AbstractRayRuntime runtime, ObjectStoreLink store) { } public Pair get(UniqueId objectId, boolean isMetadata) - throws TaskExecutionException { - return get(objectId, getTimeoutMs, isMetadata); + throws RayException { + return get(objectId, GET_TIMEOUT_MS, isMetadata); } public Pair get(UniqueId id, int timeoutMs, boolean isMetadata) - throws TaskExecutionException { + throws RayException { byte[] obj = store.get(id.getBytes(), timeoutMs, isMetadata); if (obj != null) { T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader()); store.release(id.getBytes()); - if (t instanceof TaskExecutionException) { - throw (TaskExecutionException) t; + if (t instanceof RayException) { + throw (RayException) t; } return Pair.of(t, GetStatus.SUCCESS); } else { @@ -45,21 +47,21 @@ public Pair get(UniqueId id, int timeoutMs, boolean isMetadata } public List> get(List objectIds, boolean isMetadata) - throws TaskExecutionException { - return get(objectIds, getTimeoutMs, isMetadata); + throws RayException { + return get(objectIds, GET_TIMEOUT_MS, isMetadata); } public List> get(List ids, int timeoutMs, boolean isMetadata) - throws TaskExecutionException { - List objs = store.get(getIdBytes(ids), timeoutMs, isMetadata); + throws RayException { + List objs = store.get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata); List> ret = new ArrayList<>(); for (int i = 0; i < objs.size(); i++) { byte[] obj = objs.get(i); if (obj != null) { T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader()); store.release(ids.get(i).getBytes()); - if (t instanceof TaskExecutionException) { - throw (TaskExecutionException) t; + if (t instanceof RayException) { + throw (RayException) t; } ret.add(Pair.of(t, GetStatus.SUCCESS)); } else { @@ -69,15 +71,6 @@ public List> get(List ids, int timeoutMs, boole return ret; } - private static byte[][] getIdBytes(List objectIds) { - int size = objectIds.size(); - byte[][] ids = new byte[size][]; - for (int i = 0; i < size; i++) { - ids[i] = objectIds.get(i).getBytes(); - } - return ids; - } - public void put(UniqueId id, Object obj, Object metadata) { store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata)); } diff --git 
a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 95a8abdf4274d..dbe2cd3b6329b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -6,6 +6,7 @@ import java.util.concurrent.ConcurrentHashMap; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.objectstore.MockObjectStore; @@ -66,12 +67,13 @@ public TaskSpec getTask() { } @Override - public void reconstructObjects(List objectIds, boolean fetchOnly) { + public void fetchOrReconstruct(List objectIds, boolean fetchOnly, + UniqueId currentTaskId) throws RayException { } @Override - public void notifyUnblocked() { + public void notifyUnblocked(UniqueId currentTaskId) { } @@ -81,7 +83,8 @@ public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int tas } @Override - public WaitResult wait(List> waitFor, int numReturns, int timeoutMs) { + public WaitResult wait(List> waitFor, int numReturns, int + timeoutMs, UniqueId currentTaskId) { return new WaitResult( waitFor, ImmutableList.of() diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index baa32a1425334..3e3f4f1e72918 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -3,6 +3,7 @@ import java.util.List; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.task.TaskSpec; @@ -15,13 +16,15 @@ public interface RayletClient { TaskSpec getTask(); - void reconstructObjects(List objectIds, boolean fetchOnly); + void fetchOrReconstruct(List objectIds, boolean fetchOnly, UniqueId currentTaskId) + throws RayException; - void notifyUnblocked(); + void notifyUnblocked(UniqueId currentTaskId); UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex); - WaitResult wait(List> waitFor, int numReturns, int timeoutMs); + WaitResult wait(List> waitFor, int numReturns, int + timeoutMs, UniqueId currentTaskId); void freePlasmaObjects(List objectIds, boolean localOnly); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 1a78f22debec9..cd4f3fd313c6d 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -10,15 +10,16 @@ import java.util.Map; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.generated.Arg; +import org.ray.runtime.generated.Language; import org.ray.runtime.generated.ResourcePair; import org.ray.runtime.generated.TaskInfo; -import org.ray.runtime.generated.TaskLanguage; import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.UniqueIdHelper; +import org.ray.runtime.util.UniqueIdUtil; import org.ray.runtime.util.logger.RayLog; public class RayletClientImpl 
implements RayletClient { @@ -44,13 +45,15 @@ public RayletClientImpl(String schedulerSockName, UniqueId clientId, } @Override - public WaitResult wait(List> waitFor, int numReturns, int timeoutMs) { + public WaitResult wait(List> waitFor, int numReturns, int + timeoutMs, UniqueId currentTaskId) { List ids = new ArrayList<>(); for (RayObject element : waitFor) { ids.add(element.getId()); } - boolean[] ready = nativeWaitObject(client, getIdBytes(ids), numReturns, timeoutMs, false); + boolean[] ready = nativeWaitObject(client, UniqueIdUtil.getIdBytes(ids), + numReturns, timeoutMs, false, currentTaskId.getBytes()); List> readyList = new ArrayList<>(); List> unreadyList = new ArrayList<>(); @@ -86,12 +89,17 @@ public TaskSpec getTask() { } @Override - public void reconstructObjects(List objectIds, boolean fetchOnly) { - if (RayLog.core.isInfoEnabled()) { - RayLog.core.info("Reconstructing objects for task {}, object IDs are {}", - UniqueIdHelper.computeTaskId(objectIds.get(0)), objectIds); + public void fetchOrReconstruct(List objectIds, boolean fetchOnly, + UniqueId currentTaskId) throws RayException { + if (RayLog.core.isDebugEnabled()) { + RayLog.core.debug("Blocked on objects for task {}, object IDs are {}", + UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds); + } + int ret = nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds), + fetchOnly, currentTaskId.getBytes()); + if (ret != 0) { + throw new RayException("Connection closed by Raylet"); } - nativeReconstructObjects(client, getIdBytes(objectIds), fetchOnly); } @Override @@ -101,13 +109,13 @@ public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int tas } @Override - public void notifyUnblocked() { - nativeNotifyUnblocked(client); + public void notifyUnblocked(UniqueId currentTaskId) { + nativeNotifyUnblocked(client, currentTaskId.getBytes()); } @Override public void freePlasmaObjects(List objectIds, boolean localOnly) { - byte[][] objectIdsArray = getIdBytes(objectIds); + byte[][] objectIdsArray = UniqueIdUtil.getIdBytes(objectIds); nativeFreePlasmaObjects(client, objectIdsArray, localOnly); } @@ -168,7 +176,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { final int parentTaskIdOffset = fbb.createString(task.parentTaskId.toByteBuffer()); final int parentCounter = task.parentCounter; final int actorCreateIdOffset = fbb.createString(task.actorCreationId.toByteBuffer()); - final int actorCreateDummyIdOffset = fbb.createString(UniqueId.NIL.toByteBuffer()); + final int actorCreateDummyIdOffset = fbb.createString(task.actorId.toByteBuffer()); final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer()); final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer()); final int actorCounter = task.actorCounter; @@ -209,6 +217,11 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { ResourcePair.createResourcePair(fbb, keyOffset, entry.getValue()); } int requiredResourcesOffset = fbb.createVectorOfTables(requiredResourcesOffsets); + + int[] requiredPlacementResourcesOffsets = new int[0]; + int requiredPlacementResourcesOffset = + fbb.createVectorOfTables(requiredPlacementResourcesOffsets); + int[] functionDescriptorOffsets = new int[]{ fbb.createString(task.functionDescriptor.className), fbb.createString(task.functionDescriptor.name), @@ -222,7 +235,8 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { actorCreateIdOffset, actorCreateDummyIdOffset, actorIdOffset, actorHandleIdOffset, actorCounter, false, 
functionIdOffset, - argsOffset, returnsOffset, requiredResourcesOffset, TaskLanguage.JAVA, + argsOffset, returnsOffset, requiredResourcesOffset, + requiredPlacementResourcesOffset, Language.JAVA, functionDescriptorOffset); fbb.finish(root); ByteBuffer buffer = fbb.dataBuffer(); @@ -236,15 +250,6 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { return buffer; } - private static byte[][] getIdBytes(List objectIds) { - int size = objectIds.size(); - byte[][] ids = new byte[size][]; - for (int i = 0; i < size; i++) { - ids[i] = objectIds.get(i).getBytes(); - } - return ids; - } - public void destroy() { nativeDestroy(client); } @@ -258,8 +263,8 @@ public void destroy() { /// 1) pushd $Dir/java/runtime/target/classes /// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.RayletClientImpl /// 3) clang-format -i org_ray_runtime_raylet_RayletClientImpl.h - /// 4) cp org_ray_runtime_raylet_RayletClientImpl.h $Dir/src/local_scheduler/lib/java/ - /// 5) vim $Dir/src/local_scheduler/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc + /// 4) cp org_ray_runtime_raylet_RayletClientImpl.h $Dir/src/ray/raylet/lib/java/ + /// 5) vim $Dir/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc /// 6) popd private static native long nativeInit(String localSchedulerSocket, byte[] workerId, @@ -273,15 +278,15 @@ private static native void nativeSubmitTask(long client, byte[] cursorId, ByteBu private static native void nativeDestroy(long client); - private static native void nativeReconstructObjects(long client, byte[][] objectIds, - boolean fetchOnly); + private static native int nativeFetchOrReconstruct(long client, byte[][] objectIds, + boolean fetchOnly, byte[] currentTaskId); - private static native void nativeNotifyUnblocked(long client); + private static native void nativeNotifyUnblocked(long client, byte[] currentTaskId); private static native void nativePutObject(long client, byte[] taskId, byte[] objectId); private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds, - int numReturns, int timeout, boolean waitLocal); + int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId); private static native byte[] nativeGenerateTaskId(byte[] driverId, byte[] parentTaskId, int taskIndex); diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 03429c963a3a5..56940e33cbcfd 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -179,13 +179,16 @@ private void startRaylet() { rayConfig.rayletExecutablePath, rayConfig.rayletSocketName, rayConfig.objectStoreSocketName, + "0", // The object manager port. + "0", // The node manager port. rayConfig.nodeIp, rayConfig.getRedisIp(), rayConfig.getRedisPort().toString(), "0", // number of initial workers String.valueOf(maximumStartupConcurrency), ResourceUtil.getResourcesStringFromMap(rayConfig.resources), - "", // python worker command + "", // The internal config list. + buildPythonWorkerCommand(), // python worker command buildWorkerCommandRaylet() // java worker command ); @@ -205,8 +208,8 @@ private String buildWorkerCommandRaylet() { // Generate classpath based on current classpath + user-defined classpath. 
String classpath = concatPath(Stream.concat( - Stream.of(System.getProperty("java.class.path").split(":")), - rayConfig.classpath.stream() + rayConfig.classpath.stream(), + Stream.of(System.getProperty("java.class.path").split(":")) )); cmd.add(classpath); @@ -227,6 +230,8 @@ private String buildWorkerCommandRaylet() { // Config overwrite cmd.add("-Dray.redis.address=" + rayConfig.getRedisAddress()); + cmd.addAll(rayConfig.jvmParameters); + // Main class cmd.add(WORKER_CLASS); String command = Joiner.on(" ").join(cmd); @@ -245,4 +250,22 @@ private void startObjectStore() { startProcess(command, null, "plasma_store"); } + private String buildPythonWorkerCommand() { + // disable python worker start from raylet, which starts from java + if (rayConfig.pythonWorkerCommand == null) { + return ""; + } + + List cmd = new ArrayList<>(); + cmd.add(rayConfig.pythonWorkerCommand); + cmd.add("--node-ip-address=" + rayConfig.nodeIp); + cmd.add("--object-store-name=" + rayConfig.objectStoreSocketName); + cmd.add("--raylet-name=" + rayConfig.rayletSocketName); + cmd.add("--redis-address=" + rayConfig.getRedisAddress()); + + String command = cmd.stream().collect(Collectors.joining(" ")); + LOGGER.debug("python worker command: {}", command); + return command; + } + } diff --git a/java/runtime/src/main/java/org/ray/runtime/util/JarLoader.java b/java/runtime/src/main/java/org/ray/runtime/util/JarLoader.java index 8a66923e3464a..c6ab5650c038f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/JarLoader.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/JarLoader.java @@ -14,13 +14,16 @@ import org.apache.commons.io.IOUtils; import org.apache.commons.io.filefilter.DirectoryFileFilter; import org.apache.commons.io.filefilter.RegexFileFilter; -import org.ray.runtime.util.logger.RayLog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * load and unload jars from a dir. */ public class JarLoader { + private static final Logger LOGGER = LoggerFactory.getLogger(JarLoader.class); + public static URLClassLoader loadJars(String dir, boolean explicitLoad) { // get all jars Collection jars = FileUtils.listFiles( @@ -42,7 +45,7 @@ private static URLClassLoader loadJar(Collection appJars, boolean explicit for (File appJar : appJars) { try { - RayLog.core.info("load jar " + appJar.getAbsolutePath()); + LOGGER.info("succeeded to load jar {}.", appJar.getAbsolutePath()); JarFile jar = new JarFile(appJar.getAbsolutePath()); jars.add(jar); urls.add(appJar.toURI().toURL()); diff --git a/java/runtime/src/main/java/org/ray/runtime/util/ResourceUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/ResourceUtil.java index 98cc436312423..4863ca5d13c1d 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/ResourceUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/ResourceUtil.java @@ -2,59 +2,11 @@ import java.util.HashMap; import java.util.Map; -import org.ray.api.annotation.RayRemote; -import org.ray.api.annotation.ResourceItem; public class ResourceUtil { public static final String CPU_LITERAL = "CPU"; public static final String GPU_LITERAL = "GPU"; - /** - * Convert the array that contains resource items to a map. - * - * @param remoteAnnotation The RayRemote annotation that contains the resource items. - * @return The map whose key represents the resource name - * and the value represents the resource quantity. 
- */ - public static Map getResourcesMapFromArray(RayRemote remoteAnnotation) { - Map resourceMap = new HashMap<>(); - if (remoteAnnotation != null) { - for (ResourceItem item : remoteAnnotation.resources()) { - if (!item.name().isEmpty()) { - resourceMap.put(item.name(), item.value()); - } - } - } - if (!resourceMap.containsKey(CPU_LITERAL)) { - resourceMap.put(CPU_LITERAL, 0.0); - } - return resourceMap; - } - - /** - * Convert the resources map to a format string. - * - * @param resources The resource map to be Converted. - * @return The format resources string, like "{CPU:4, GPU:0}". - */ - public static String getResourcesFromatStringFromMap(Map resources) { - if (resources == null) { - return "{}"; - } - StringBuilder builder = new StringBuilder(); - builder.append("{"); - int count = 1; - for (Map.Entry entry : resources.entrySet()) { - builder.append(entry.getKey()).append(":").append(entry.getValue()); - count++; - if (count != resources.size()) { - builder.append(", "); - } - } - builder.append("}"); - return builder.toString(); - } - /** * Convert resources map to a string that is used * for the command line argument of starting raylet. @@ -99,7 +51,7 @@ public static Map getResourcesMapFromString(String resources) String[] resourcePair = trimItem.split(":"); if (resourcePair.length != 2) { - throw new IllegalArgumentException("Format of static resurces configure is invalid."); + throw new IllegalArgumentException("Format of static resources configure is invalid."); } final String resourceName = resourcePair[0].trim(); diff --git a/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdHelper.java b/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java similarity index 81% rename from java/runtime/src/main/java/org/ray/runtime/util/UniqueIdHelper.java rename to java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java index 52d9a7359247a..d7b347945792c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdHelper.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java @@ -3,6 +3,8 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; +import java.util.List; + import org.ray.api.id.UniqueId; @@ -11,7 +13,7 @@ * Note: any changes to these methods must be synced with C++ helper functions * in src/ray/id.h */ -public class UniqueIdHelper { +public class UniqueIdUtil { public static final int OBJECT_INDEX_POS = 0; public static final int OBJECT_INDEX_LENGTH = 4; @@ -37,7 +39,7 @@ private static UniqueId computeObjectId(UniqueId taskId, int index) { System.arraycopy(taskId.getBytes(),0, objId, 0, UniqueId.LENGTH); ByteBuffer wbb = ByteBuffer.wrap(objId); wbb.order(ByteOrder.LITTLE_ENDIAN); - wbb.putInt(UniqueIdHelper.OBJECT_INDEX_POS, index); + wbb.putInt(UniqueIdUtil.OBJECT_INDEX_POS, index); return new UniqueId(objId); } @@ -63,9 +65,18 @@ public static UniqueId computePutId(UniqueId taskId, int putIndex) { public static UniqueId computeTaskId(UniqueId objectId) { byte[] taskId = new byte[UniqueId.LENGTH]; System.arraycopy(objectId.getBytes(), 0, taskId, 0, UniqueId.LENGTH); - Arrays.fill(taskId, UniqueIdHelper.OBJECT_INDEX_POS, - UniqueIdHelper.OBJECT_INDEX_POS + UniqueIdHelper.OBJECT_INDEX_LENGTH, (byte) 0); + Arrays.fill(taskId, UniqueIdUtil.OBJECT_INDEX_POS, + UniqueIdUtil.OBJECT_INDEX_POS + UniqueIdUtil.OBJECT_INDEX_LENGTH, (byte) 0); return new UniqueId(taskId); } + + public static byte[][] getIdBytes(List objectIds) { + int size = objectIds.size(); + byte[][] ids = new byte[size][]; + for 
(int i = 0; i < size; i++) { + ids[i] = objectIds.get(i).getBytes(); + } + return ids; + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/util/exception/TaskExecutionException.java b/java/runtime/src/main/java/org/ray/runtime/util/exception/TaskExecutionException.java deleted file mode 100644 index 99bc0912e1d07..0000000000000 --- a/java/runtime/src/main/java/org/ray/runtime/util/exception/TaskExecutionException.java +++ /dev/null @@ -1,15 +0,0 @@ -package org.ray.runtime.util.exception; - -/** - * An exception which is thrown when a ray task encounters an error when executing. - */ -public class TaskExecutionException extends RuntimeException { - - public TaskExecutionException(Throwable cause) { - super(cause); - } - - public TaskExecutionException(String message, Throwable cause) { - super(message, cause); - } -} diff --git a/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java b/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java index 10ffc3488f287..82fdf6b7f99e1 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java @@ -21,7 +21,17 @@ private String build() { newLine(""); newLine("package org.ray.api;"); newLine(""); - newLine("import org.ray.api.function.*;"); + newLine("import org.ray.api.function.RayFunc;"); + newLine("import org.ray.api.function.RayFunc0;"); + newLine("import org.ray.api.function.RayFunc1;"); + newLine("import org.ray.api.function.RayFunc2;"); + newLine("import org.ray.api.function.RayFunc3;"); + newLine("import org.ray.api.function.RayFunc4;"); + newLine("import org.ray.api.function.RayFunc5;"); + newLine("import org.ray.api.function.RayFunc6;"); + newLine("import org.ray.api.options.ActorCreationOptions;"); + newLine("import org.ray.api.options.BaseTaskOptions;"); + newLine("import org.ray.api.options.CallOptions;"); newLine(""); newLine("/**"); @@ -33,19 +43,21 @@ private String build() { newLine(1, "// Methods for remote function invocation."); newLine(1, "// ======================================="); for (int i = 0; i <= MAX_PARAMETERS; i++) { - buildCalls(i, false, false); + buildCalls(i, false, false, false); + buildCalls(i, false, false, true); } newLine(1, "// ==========================================="); newLine(1, "// Methods for remote actor method invocation."); newLine(1, "// ==========================================="); for (int i = 0; i <= MAX_PARAMETERS - 1; i++) { - buildCalls(i, true, false); + buildCalls(i, true, false, false); } newLine(1, "// ==========================="); newLine(1, "// Methods for actor creation."); newLine(1, "// ==========================="); for (int i = 0; i <= MAX_PARAMETERS; i++) { - buildCalls(i, false, true); + buildCalls(i, false, true, false); + buildCalls(i, false, true, true); } newLine("}"); return sb.toString(); @@ -57,7 +69,8 @@ private String build() { * @param forActor build actor api when true, otherwise build task api. * @param forActorCreation build `Ray.createActor` when true, otherwise build `Ray.call`. 
*/ - private void buildCalls(int numParameters, boolean forActor, boolean forActorCreation) { + private void buildCalls(int numParameters, boolean forActor, + boolean forActorCreation, boolean hasOptionsParam) { String genericTypes = ""; String argList = ""; for (int i = 0; i < numParameters; i++) { @@ -82,18 +95,36 @@ private void buildCalls(int numParameters, boolean forActor, boolean forActorCre paramPrefix += ", "; } + String optionsParam; + if (hasOptionsParam) { + optionsParam = forActorCreation ? ", ActorCreationOptions options" : ", CallOptions options"; + } else { + optionsParam = ""; + } + + String optionsArg; + if (forActor) { + optionsArg = ""; + } else { + if (hasOptionsParam) { + optionsArg = ", options"; + } else { + optionsArg = ", null"; + } + } + String returnType = !forActorCreation ? "RayObject" : "RayActor"; String funcName = !forActorCreation ? "call" : "createActor"; String funcArgs = !forActor ? "f, args" : "f, actor, args"; for (String param : generateParameters(0, numParameters)) { // method signature newLine(1, String.format( - "public static <%s> %s %s(%s) {", - genericTypes, returnType, funcName, paramPrefix + param + "public static <%s> %s %s(%s%s) {", + genericTypes, returnType, funcName, paramPrefix + param, optionsParam )); // method body newLine(2, String.format("Object[] args = new Object[]{%s};", argList)); - newLine(2, String.format("return Ray.internal().%s(%s);", funcName, funcArgs)); + newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg)); newLine(1, "}"); } } diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index c20d679a9c598..b45d7dc6376d4 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -25,9 +25,16 @@ ray { // Available resources on this node, for example "CPU:4,GPU:0". resources: "" - // If worker.mode is DRIVER, specify the driver id. - // If not provided, a random id will be used. - driver.id: "" + // Configuration items about the driver. + driver { + // If worker.mode is DRIVER, specify the driver id. + // If not provided, a random id will be used. + id: "" + // If this config is set, the worker will use different paths to load resources when + // executing tasks from different drivers. E.g. if it's set to '/tmp/driver_resources', + // the path for driver 123 will be '/tmp/driver_resources/123'. + resource-path: "" + } // Root dir of log files. log-dir: /tmp/ray/logs @@ -36,6 +43,9 @@ ray { // Otherwise, output will be printed to console. redirect-output: true + // Custom worker jvm parameters. + worker.jvm-parameters: [] + // Custom `java.library.path` // Note, do not use `dir1:dir2` format, put each dir as a list item.
library.path: [] @@ -76,4 +86,5 @@ ray { // RPC socket name of Raylet socket-name: /tmp/ray/sockets/raylet } + } diff --git a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java index 85f482544c84d..f5ff1e481a36e 100644 --- a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java +++ b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java @@ -1,5 +1,9 @@ package org.ray.runtime.functionmanager; +import java.io.File; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; import java.util.Map; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; @@ -41,8 +45,6 @@ public Object bar() { private static FunctionDescriptor barDescriptor; private static FunctionDescriptor barConstructorDescriptor; - private FunctionManager functionManager; - @BeforeClass public static void beforeClass() { fooFunc = FunctionManagerTest::foo; @@ -57,13 +59,9 @@ public static void beforeClass() { "()V"); } - @Before - public void before() { - functionManager = new FunctionManager(); - } - @Test public void testGetFunctionFromRayFunc() { + final FunctionManager functionManager = new FunctionManager(null); // Test normal function. RayFunction func = functionManager.getFunction(UniqueId.NIL, fooFunc); Assert.assertFalse(func.isConstructor()); @@ -74,7 +72,7 @@ public void testGetFunctionFromRayFunc() { func = functionManager.getFunction(UniqueId.NIL, barFunc); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor); - Assert.assertNotNull(func.getRayRemoteAnnotation()); + Assert.assertNull(func.getRayRemoteAnnotation()); // Test actor constructor func = functionManager.getFunction(UniqueId.NIL, barConstructor); @@ -85,6 +83,7 @@ public void testGetFunctionFromRayFunc() { @Test public void testGetFunctionFromFunctionDescriptor() { + final FunctionManager functionManager = new FunctionManager(null); // Test normal function. RayFunction func = functionManager.getFunction(UniqueId.NIL, fooDescriptor); Assert.assertFalse(func.isConstructor()); @@ -95,7 +94,7 @@ public void testGetFunctionFromFunctionDescriptor() { func = functionManager.getFunction(UniqueId.NIL, barDescriptor); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor); - Assert.assertNotNull(func.getRayRemoteAnnotation()); + Assert.assertNull(func.getRayRemoteAnnotation()); // Test actor constructor func = functionManager.getFunction(UniqueId.NIL, barConstructorDescriptor); @@ -116,4 +115,28 @@ public void testLoadFunctionTableForClass() { Assert.assertTrue(res.containsKey( ImmutablePair.of(barConstructorDescriptor.name, barConstructorDescriptor.typeDescriptor))); } + + //TODO(qwang): This is an integration test case, and we should move it to test folder in the future. + @Test + public void testGetFunctionFromLocalResource() throws Exception{ + UniqueId driverId = UniqueId.fromHexString("0123456789012345678901234567890123456789"); + + //TODO(qwang): We should use a independent app demo instead of `tutorial`. 
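+    // This test copies the tutorial jar into <resource-path>/<driver-id>/ and then verifies
+    // that FunctionManager can resolve a function from that per-driver directory.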
+ final String resourcePath = "/tmp/ray/test/resource"; + final String srcJarPath = System.getProperty("user.dir") + + "/../tutorial/target/ray-tutorial-0.1-SNAPSHOT.jar"; + final String destJarPath = resourcePath + "/" + driverId.toString() + + "/ray-tutorial-0.1-SNAPSHOT.jar"; + + File file = new File(resourcePath + "/" + driverId.toString()); + file.mkdirs(); + Files.copy(Paths.get(srcJarPath), Paths.get(destJarPath), StandardCopyOption.REPLACE_EXISTING); + + final FunctionManager functionManager = new FunctionManager(resourcePath); + FunctionDescriptor sayHelloDescriptor = new FunctionDescriptor("org.ray.exercise.Exercise02", + "sayHello", "()Ljava/lang/String;"); + RayFunction func = functionManager.getFunction(driverId, sayHelloDescriptor); + Assert.assertEquals(func.getFunctionDescriptor(), sayHelloDescriptor); + } + } diff --git a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java index fd47e15ab494b..71e3d0dfff8e7 100644 --- a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java @@ -11,6 +11,7 @@ public class RayConfigTest { @Test public void testCreateRayConfig() { System.setProperty("ray.home", "/path/to/ray"); + System.setProperty("ray.driver.resource-path", "path/to/ray/driver/resource/path"); RayConfig rayConfig = RayConfig.create(); Assert.assertEquals("/path/to/ray", rayConfig.rayHome); @@ -19,8 +20,12 @@ public void testCreateRayConfig() { System.setProperty("ray.home", ""); rayConfig = RayConfig.create(); + Assert.assertEquals(System.getProperty("user.dir"), rayConfig.rayHome); Assert.assertEquals(System.getProperty("user.dir") + - "/build/src/common/thirdparty/redis/src/redis-server", rayConfig.redisServerExecutablePath); + "/build/src/ray/thirdparty/redis/src/redis-server", rayConfig.redisServerExecutablePath); + + Assert.assertEquals("path/to/ray/driver/resource/path", rayConfig.driverResourcePath); + } } diff --git a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java index 001723fec4bf6..e185a5f19a894 100644 --- a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java +++ b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java @@ -1,6 +1,8 @@ package org.ray.api.test; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import jdk.nashorn.internal.ir.annotations.Immutable; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -9,7 +11,8 @@ import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.annotation.RayRemote; -import org.ray.api.annotation.ResourceItem; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.CallOptions; /** * Resources Management Test. 
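For reference, a minimal sketch of how the options-based API introduced in this patch is meant to be used; the test hunks below exercise exactly this pattern. `OptionsUsageSketch` and its members are illustrative names, not part of the diff:

// Illustrative only -- not part of this patch.
import com.google.common.collect.ImmutableMap;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;

public class OptionsUsageSketch {

  @RayRemote
  public static Integer echo(Integer x) {
    return x;
  }

  @RayRemote
  public static class Counter {
    public Integer ping() {
      return 1;
    }
  }

  public static void main(String[] args) {
    Ray.init();

    // Resource demands are now passed per call instead of via @RayRemote(resources = ...).
    CallOptions callOptions = new CallOptions(ImmutableMap.of("CPU", 1.0));
    RayObject<Integer> result = Ray.call(OptionsUsageSketch::echo, 42, callOptions);

    // Actors declare their resource demand once, at creation time.
    ActorCreationOptions creationOptions =
        new ActorCreationOptions(ImmutableMap.of("CPU", 1.0));
    RayActor<Counter> counter = Ray.createActor(Counter::new, creationOptions);
    RayObject<Integer> pong = Ray.call(Counter::ping, counter);

    System.out.println(result.get() + " " + pong.get());
  }
}

Note that, per this diff, only `Ray.call` for plain tasks and `Ray.createActor` gain the extra options overloads; calls on an existing actor take no options argument.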
@@ -17,29 +20,13 @@ @RunWith(MyRunner.class) public class ResourcesManagementTest { - @RayRemote(resources = {@ResourceItem(name = "CPU", value = 4), - @ResourceItem(name = "GPU", value = 0)}) - public static Integer echo1(Integer number) { + @RayRemote + public static Integer echo(Integer number) { return number; } - @RayRemote(resources = {@ResourceItem(name = "CPU", value = 4), - @ResourceItem(name = "GPU", value = 2)}) - public static Integer echo2(Integer number) { - return number; - } - - @RayRemote(resources = {@ResourceItem(name = "CPU", value = 2), - @ResourceItem(name = "GPU", value = 0)}) - public static class Echo1 { - public Integer echo(Integer number) { - return number; - } - } - - @RayRemote(resources = {@ResourceItem(name = "CPU", value = 8), - @ResourceItem(name = "GPU", value = 0)}) - public static class Echo2 { + @RayRemote + public static class Echo { public Integer echo(Integer number) { return number; } @@ -47,12 +34,18 @@ public Integer echo(Integer number) { @Test public void testMethods() { + CallOptions callOptions1 = new CallOptions(ImmutableMap.of("CPU", 4.0, "GPU", 0.0)); + // This is a case that can satisfy required resources. - RayObject result1 = Ray.call(ResourcesManagementTest::echo1, 100); + // The static resources for test are "CPU:4,RES-A:4". + RayObject result1 = Ray.call(ResourcesManagementTest::echo, 100, callOptions1); Assert.assertEquals(100, (int) result1.get()); + CallOptions callOptions2 = new CallOptions(ImmutableMap.of("CPU", 4.0, "GPU", 2.0)); + // This is a case that can't satisfy required resources. - final RayObject result2 = Ray.call(ResourcesManagementTest::echo2, 200); + // The static resources for test are "CPU:4,RES-A:4". + final RayObject result2 = Ray.call(ResourcesManagementTest::echo, 200, callOptions2); WaitResult waitResult = Ray.wait(ImmutableList.of(result2), 1, 1000); Assert.assertEquals(0, waitResult.getReady().size()); @@ -61,14 +54,24 @@ public void testMethods() { @Test public void testActors() { + + ActorCreationOptions actorCreationOptions1 = + new ActorCreationOptions(ImmutableMap.of("CPU", 2.0, "GPU", 0.0)); + // This is a case that can satisfy required resources. - RayActor echo1 = Ray.createActor(Echo1::new); - final RayObject result1 = Ray.call(Echo1::echo, echo1, 100); + // The static resources for test are "CPU:4,RES-A:4". + RayActor echo1 = Ray.createActor(Echo::new, actorCreationOptions1); + final RayObject result1 = Ray.call(Echo::echo, echo1, 100); Assert.assertEquals(100, (int) result1.get()); // This is a case that can't satisfy required resources. - RayActor echo2 = Ray.createActor(Echo2::new); - final RayObject result2 = Ray.call(Echo2::echo, echo2, 100); + // The static resources for test are "CPU:4,RES-A:4". 
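+    // Requesting "CPU": 8.0 exceeds the configured "CPU:4", so this actor can never be
+    // scheduled and the wait below finds no ready result.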
+ ActorCreationOptions actorCreationOptions2 = + new ActorCreationOptions(ImmutableMap.of("CPU", 8.0, "GPU", 0.0)); + + RayActor echo2 = + Ray.createActor(Echo::new, actorCreationOptions2); + final RayObject result2 = Ray.call(Echo::echo, echo2, 100); WaitResult waitResult = Ray.wait(ImmutableList.of(result2), 1, 1000); Assert.assertEquals(0, waitResult.getReady().size()); diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java new file mode 100644 index 0000000000000..4fab74aed1991 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -0,0 +1,98 @@ +package org.ray.api.test; + +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.List; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.RayObject; +import org.ray.api.id.UniqueId; + +@RunWith(MyRunner.class) +public class StressTest { + + public static int echo(int x) { + return x; + } + + @Test + public void testSubmittingTasks() { + for (int numIterations : ImmutableList.of(1, 10, 100, 1000)) { + int numTasks = 1000 / numIterations; + for (int i = 0; i < numIterations; i++) { + List resultIds = new ArrayList<>(); + for (int j = 0; j < numTasks; j++) { + resultIds.add(Ray.call(StressTest::echo, 1).getId()); + } + for (Integer result : Ray.get(resultIds)) { + Assert.assertEquals(result, Integer.valueOf(1)); + } + } + } + } + + @Test + public void testDependency() { + RayObject x = Ray.call(StressTest::echo, 1); + for (int i = 0; i < 1000; i++) { + x = Ray.call(StressTest::echo, x); + } + Assert.assertEquals(x.get(), Integer.valueOf(1)); + } + + public static class Actor { + + public int ping() { + return 1; + } + } + + public static class Worker { + + private RayActor actor; + + public Worker(RayActor actor) { + this.actor = actor; + } + + public int ping(int n) { + List objectIds = new ArrayList<>(); + for (int i = 0; i < n; i++) { + objectIds.add(Ray.call(Actor::ping, actor).getId()); + } + int sum = 0; + for (Integer result : Ray.get(objectIds)) { + sum += result; + } + return sum; + } + } + + @Test + public void testSubmittingManyTasksToOneActor() { + RayActor actor = Ray.createActor(Actor::new); + List objectIds = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + RayActor worker = Ray.createActor(Worker::new, actor); + objectIds.add(Ray.call(Worker::ping, worker, 100).getId()); + } + for (Integer result : Ray.get(objectIds)) { + Assert.assertEquals(result, Integer.valueOf(100)); + } + } + + @Test + public void testPuttingAndGettingManyObjects() { + Integer objectToPut = 1; + List> objects = new ArrayList<>(); + for (int i = 0; i < 100_000; i++) { + objects.add(Ray.put(objectToPut)); + } + for (RayObject object : objects) { + Assert.assertEquals(object.get(), objectToPut); + } + } +} diff --git a/java/test/src/main/java/org/ray/api/test/TestListener.java b/java/test/src/main/java/org/ray/api/test/TestListener.java index 3fb16bf4f379f..efc419b34720e 100644 --- a/java/test/src/main/java/org/ray/api/test/TestListener.java +++ b/java/test/src/main/java/org/ray/api/test/TestListener.java @@ -10,7 +10,7 @@ public class TestListener extends RunListener { @Override public void testRunStarted(Description description) { System.setProperty("ray.home", "../.."); - System.setProperty("ray.resources", "CPU:4"); + System.setProperty("ray.resources", "CPU:4,RES-A:4"); Ray.init(); } diff --git 
a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java index 0a21fc2872bf3..2fd47057d90de 100644 --- a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java +++ b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java @@ -5,18 +5,16 @@ import javax.xml.bind.DatatypeConverter; import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.id.UniqueId; -import org.ray.runtime.util.UniqueIdHelper; +import org.ray.runtime.util.UniqueIdUtil; -@RunWith(MyRunner.class) public class UniqueIdTest { @Test public void testConstructUniqueId() { // Test `fromHexString()` UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); - Assert.assertEquals("00000000123456789ABCDEF123456789ABCDEF00", id1.toString()); + Assert.assertEquals("00000000123456789abcdef123456789abcdef00", id1.toString()); Assert.assertFalse(id1.isNil()); try { @@ -40,12 +38,12 @@ public void testConstructUniqueId() { ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 20); UniqueId id4 = UniqueId.fromByteBuffer(byteBuffer); Assert.assertTrue(Arrays.equals(bytes, id4.getBytes())); - Assert.assertEquals("0123456789ABCDEF0123456789ABCDEF01234567", id4.toString()); + Assert.assertEquals("0123456789abcdef0123456789abcdef01234567", id4.toString()); // Test `genNil()` UniqueId id6 = UniqueId.genNil(); - Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", id6.toString()); + Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString()); Assert.assertTrue(id6.isNil()); } @@ -54,19 +52,19 @@ public void testComputeReturnId() { // Mock a taskId, and the lowest 4 bytes should be 0. UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); - UniqueId returnId = UniqueIdHelper.computeReturnId(taskId, 1); - Assert.assertEquals("01000000123456789ABCDEF123456789ABCDEF00", returnId.toString()); + UniqueId returnId = UniqueIdUtil.computeReturnId(taskId, 1); + Assert.assertEquals("01000000123456789abcdef123456789abcdef00", returnId.toString()); - returnId = UniqueIdHelper.computeReturnId(taskId, 0x01020304); - Assert.assertEquals("04030201123456789ABCDEF123456789ABCDEF00", returnId.toString()); + returnId = UniqueIdUtil.computeReturnId(taskId, 0x01020304); + Assert.assertEquals("04030201123456789abcdef123456789abcdef00", returnId.toString()); } @Test public void testComputeTaskId() { UniqueId objId = UniqueId.fromHexString("34421980123456789ABCDEF123456789ABCDEF00"); - UniqueId taskId = UniqueIdHelper.computeTaskId(objId); + UniqueId taskId = UniqueIdUtil.computeTaskId(objId); - Assert.assertEquals("00000000123456789ABCDEF123456789ABCDEF00", taskId.toString()); + Assert.assertEquals("00000000123456789abcdef123456789abcdef00", taskId.toString()); } @Test @@ -74,11 +72,11 @@ public void testComputePutId() { // Mock a taskId, the lowest 4 bytes should be 0. 
UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); - UniqueId putId = UniqueIdHelper.computePutId(taskId, 1); - Assert.assertEquals("FFFFFFFF123456789ABCDEF123456789ABCDEF00", putId.toString()); + UniqueId putId = UniqueIdUtil.computePutId(taskId, 1); + Assert.assertEquals("FFFFFFFF123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); - putId = UniqueIdHelper.computePutId(taskId, 0x01020304); - Assert.assertEquals("FCFCFDFE123456789ABCDEF123456789ABCDEF00", putId.toString()); + putId = UniqueIdUtil.computePutId(taskId, 0x01020304); + Assert.assertEquals("FCFCFDFE123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); } } diff --git a/java/tutorial/pom.xml b/java/tutorial/pom.xml index 198f6f0a3a51e..48a03dc1ca8e1 100644 --- a/java/tutorial/pom.xml +++ b/java/tutorial/pom.xml @@ -40,7 +40,7 @@ ${basedir}/../ray.config.ini -ea - -Djava.library.path=${basedir}/../../build/src/plasma:${basedir}/../../build/src/local_scheduler + -Djava.library.path=${basedir}/../../build/src/plasma:${basedir}/../../build/src/ray/raylet -noverify -DlogOutput=console diff --git a/python/benchmarks/benchmark_actor.py b/python/benchmarks/benchmark_actor.py index b0450c14de6a6..2eb476e1f172f 100644 --- a/python/benchmarks/benchmark_actor.py +++ b/python/benchmarks/benchmark_actor.py @@ -9,7 +9,7 @@ def setup(): if not hasattr(setup, "is_initialized"): - ray.init(num_workers=NUM_WORKERS, num_cpus=4) + ray.init(num_cpus=4) setup.is_initialized = True diff --git a/python/benchmarks/benchmark_get.py b/python/benchmarks/benchmark_get.py index 27a848e9cf3c6..fccfc00e0f709 100644 --- a/python/benchmarks/benchmark_get.py +++ b/python/benchmarks/benchmark_get.py @@ -9,7 +9,7 @@ def setup(): if not hasattr(setup, "is_initialized"): - ray.init(num_workers=4, num_cpus=4) + ray.init(num_cpus=4) setup.is_initialized = True diff --git a/python/benchmarks/benchmark_put.py b/python/benchmarks/benchmark_put.py index 986a28c89f283..e74bf099666ac 100644 --- a/python/benchmarks/benchmark_put.py +++ b/python/benchmarks/benchmark_put.py @@ -9,7 +9,7 @@ def setup(): if not hasattr(setup, "is_initialized"): - ray.init(num_workers=4, num_cpus=4) + ray.init(num_cpus=0) setup.is_initialized = True diff --git a/python/benchmarks/benchmark_queue.py b/python/benchmarks/benchmark_queue.py index bc4ec6a41ee5d..fd8a4a6eb13a8 100644 --- a/python/benchmarks/benchmark_queue.py +++ b/python/benchmarks/benchmark_queue.py @@ -8,7 +8,7 @@ def setup(): if not hasattr(setup, "is_initialized"): - ray.init(num_workers=4, num_cpus=4) + ray.init(num_cpus=4) setup.is_initialized = True diff --git a/python/benchmarks/benchmark_task.py b/python/benchmarks/benchmark_task.py index 30a4bb8cb1f23..b454f63277fc5 100644 --- a/python/benchmarks/benchmark_task.py +++ b/python/benchmarks/benchmark_task.py @@ -7,7 +7,7 @@ def setup(): if not hasattr(setup, "is_initialized"): - ray.init(num_workers=10, num_cpus=10, resources={"foo": 1}) + ray.init(num_cpus=10, resources={"foo": 1}) setup.is_initialized = True diff --git a/python/benchmarks/benchmark_wait.py b/python/benchmarks/benchmark_wait.py index b40c0463a6cbe..614d76a38c54f 100644 --- a/python/benchmarks/benchmark_wait.py +++ b/python/benchmarks/benchmark_wait.py @@ -9,7 +9,7 @@ def setup(*args): if not hasattr(setup, "is_initialized"): - ray.init(num_workers=4, num_cpus=4) + ray.init(num_cpus=4) setup.is_initialized = True diff --git a/python/benchmarks/benchmarks.py b/python/benchmarks/benchmarks.py index 6eac996a48a55..c286e1ef6ec0f 100644 --- 
a/python/benchmarks/benchmarks.py +++ b/python/benchmarks/benchmarks.py @@ -7,7 +7,7 @@ def setup(): if not hasattr(setup, "is_initialized"): - ray.init(num_workers=4, num_cpus=4) + ray.init(num_cpus=4) setup.is_initialized = True diff --git a/python/build-wheel-macos.sh b/python/build-wheel-macos.sh index 588362e8099e2..30e8b19363769 100755 --- a/python/build-wheel-macos.sh +++ b/python/build-wheel-macos.sh @@ -16,15 +16,24 @@ DOWNLOAD_DIR=python_downloads PY_VERSIONS=("2.7.13" "3.4.4" "3.5.3" - "3.6.1") + "3.6.1" + "3.7.0") PY_INSTS=("python-2.7.13-macosx10.6.pkg" "python-3.4.4-macosx10.6.pkg" "python-3.5.3-macosx10.6.pkg" - "python-3.6.1-macosx10.6.pkg") + "python-3.6.1-macosx10.6.pkg" + "python-3.7.0-macosx10.6.pkg") PY_MMS=("2.7" "3.4" "3.5" - "3.6") + "3.6" + "3.7") +# On python 3.7, a newer version of numpy seems to be necessary. +NUMPY_VERSIONS=("1.10.4" + "1.10.4" + "1.10.4" + "1.10.4" + "1.14.5") mkdir -p $DOWNLOAD_DIR mkdir -p .whl @@ -33,6 +42,7 @@ for ((i=0; i<${#PY_VERSIONS[@]}; ++i)); do PY_VERSION=${PY_VERSIONS[i]} PY_INST=${PY_INSTS[i]} PY_MM=${PY_MMS[i]} + NUMPY_VERSION=${NUMPY_VERSIONS[i]} # The -f flag is passed twice to also run git clean in the arrow subdirectory. # The -d flag removes directories. The -x flag ignores the .gitignore file, @@ -60,7 +70,7 @@ for ((i=0; i<${#PY_VERSIONS[@]}; ++i)); do $PIP_CMD install -q setuptools_scm==2.1.0 # Fix the numpy version because this will be the oldest numpy version we can # support. - $PIP_CMD install -q numpy==1.10.4 cython==0.27.3 + $PIP_CMD install -q numpy==$NUMPY_VERSION cython==0.27.3 # Install wheel to avoid the error "invalid command 'bdist_wheel'". $PIP_CMD install -q wheel # Add the correct Python to the path and build the wheel. This is only diff --git a/python/build-wheel-manylinux1.sh b/python/build-wheel-manylinux1.sh index 8fdee4a1a480e..db31ff55a4e6e 100755 --- a/python/build-wheel-manylinux1.sh +++ b/python/build-wheel-manylinux1.sh @@ -13,7 +13,7 @@ rm -f /usr/bin/python2 ln -s /opt/python/cp27-cp27m/bin/python2 /usr/bin/python2 mkdir .whl -for PYTHON in cp27-cp27mu cp34-cp34m cp35-cp35m cp36-cp36m; do +for PYTHON in cp27-cp27mu cp34-cp34m cp35-cp35m cp36-cp36m cp37-cp37m; do # The -f flag is passed twice to also run git clean in the arrow subdirectory. # The -d flag removes directories. The -x flag ignores the .gitignore file, # and the -e flag ensures that we don't remove the .whl directory. diff --git a/python/ray/WebUI.ipynb b/python/ray/WebUI.ipynb index 390263827e037..229366eba10b6 100644 --- a/python/ray/WebUI.ipynb +++ b/python/ray/WebUI.ipynb @@ -1,150 +1,97 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Ray UI\n", - "\n", - "Start the UI with **Kernel -> Restart and Run All**." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import ray\n", - "import ray.experimental.ui as ui\n", - "\n", - "ray.init(redis_address=os.environ[\"REDIS_ADDRESS\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Object search." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ui.object_search_bar()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Task search." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ui.task_search_bar()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Task trace timeline." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To view arrows, go to View Options and select Flow Events." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ui.task_timeline()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Task durations." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ui.task_completion_time_distribution()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### CPU usage." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ui.cpu_usage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Cluster usage." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ui.cluster_usage()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.1" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "cells": [{ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Ray UI\n", "\n", + "Start the UI with **Kernel -> Restart and Run All**." + ] + }, { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", "import ray\n", + "import ray.experimental.ui as ui\n", "\n", + "ray.init(redis_address=os.environ[\"REDIS_ADDRESS\"])" + ] + }, { + "cell_type": "markdown", + "metadata": {}, + "source": ["#### Task trace timeline."] + }, { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To view arrows, go to View Options and select Flow Events." 
+ ] + }, { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": ["ui.task_timeline()"] + }, { + "cell_type": "markdown", + "metadata": {}, + "source": ["#### Object transfer timeline."] + }, { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": ["ui.object_transfer_timeline()"] + }, { + "cell_type": "markdown", + "metadata": {}, + "source": ["#### Task durations."] + }, { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": ["ui.task_completion_time_distribution()"] + }, { + "cell_type": "markdown", + "metadata": {}, + "source": ["#### CPU usage."] + }, { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": ["ui.cpu_usage()"] + }, { + "cell_type": "markdown", + "metadata": {}, + "source": ["#### Cluster usage."] + }, { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": ["ui.cluster_usage()"] + }], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/python/ray/__init__.py b/python/ray/__init__.py index b97af4b587daa..ed024a107aa50 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -46,7 +46,10 @@ e.args += (helpful_message, ) raise -from ray.local_scheduler import ObjectID, _config # noqa: E402 +modin_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "modin") +sys.path.insert(0, modin_path) + +from ray.raylet import ObjectID, _config # noqa: E402 from ray.profiling import profile # noqa: E402 from ray.worker import (error_info, init, connect, disconnect, get, put, wait, remote, get_gpu_ids, get_resource_ids, get_webui_url, @@ -61,9 +64,8 @@ import ray.actor # noqa: F401 from ray.actor import method # noqa: E402 -# Ray version string. TODO(rkn): This is also defined separately in setup.py. -# Fix this. -__version__ = "0.5.3" +# Ray version string. 
+__version__ = "0.6.0" __all__ = [ "error_info", "init", "connect", "disconnect", "get", "put", "wait", diff --git a/python/ray/actor.py b/python/ray/actor.py index 3886e1927a02f..926f15b293644 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -5,29 +5,21 @@ import copy import hashlib import inspect -import json +import logging +import sys import traceback import ray.cloudpickle as pickle -import ray.local_scheduler +from ray.function_manager import FunctionActorManager +import ray.raylet import ray.ray_constants as ray_constants import ray.signature as signature import ray.worker -from ray.utils import ( - decode, - _random_string, - check_oversized_pickle, - is_cython, - push_error_to_driver, -) +from ray.utils import _random_string DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1 - -def is_classmethod(f): - """Returns whether the given method is a classmethod.""" - - return hasattr(f, "__self__") and f.__self__ is not None +logger = logging.getLogger(__name__) def compute_actor_handle_id(actor_handle_id, num_forks): @@ -96,24 +88,6 @@ def compute_actor_creation_function_id(class_id): return ray.ObjectID(class_id) -def compute_actor_method_function_id(class_name, attr): - """Get the function ID corresponding to an actor method. - - Args: - class_name (str): The class name of the actor. - attr (str): The attribute name of the method. - - Returns: - Function ID corresponding to the method. - """ - function_id_hash = hashlib.sha1() - function_id_hash.update(class_name.encode("ascii")) - function_id_hash.update(attr.encode("ascii")) - function_id = function_id_hash.digest() - assert len(function_id) == ray_constants.ID_SIZE - return ray.ObjectID(function_id) - - def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint, frontier): """Set the most recent checkpoint associated with a given actor ID. @@ -134,28 +108,6 @@ def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint, }) -def get_actor_checkpoint(worker, actor_id): - """Get the most recent checkpoint associated with a given actor ID. - - Args: - worker: The worker to use to get the checkpoint. - actor_id: The actor ID of the actor to get the checkpoint for. - - Returns: - If a checkpoint exists, this returns a tuple of the number of tasks - included in the checkpoint, the saved checkpoint state, and the - task frontier at the time of the checkpoint. If no checkpoint - exists, all objects are set to None. The checkpoint index is the . - executed on the actor before the checkpoint was made. - """ - actor_key = b"Actor:" + actor_id - checkpoint_index, checkpoint, frontier = worker.redis_client.hmget( - actor_key, ["checkpoint_index", "checkpoint", "frontier"]) - if checkpoint_index is not None: - checkpoint_index = int(checkpoint_index) - return checkpoint_index, checkpoint, frontier - - def save_and_log_checkpoint(worker, actor): """Save a checkpoint on the actor and log any errors. @@ -205,219 +157,26 @@ def restore_and_log_checkpoint(worker, actor): return checkpoint_resumed -def make_actor_method_executor(worker, method_name, method, actor_imported): - """Make an executor that wraps a user-defined actor method. - - The wrapped method updates the worker's internal state and performs any - necessary checkpointing operations. +def get_actor_checkpoint(worker, actor_id): + """Get the most recent checkpoint associated with a given actor ID. Args: - worker (Worker): The worker that is executing the actor. - method_name (str): The name of the actor method. - method (instancemethod): The actor method to wrap. 
This should be a - method defined on the actor class and should therefore take an - instance of the actor as the first argument. - actor_imported (bool): Whether the actor has been imported. - Checkpointing operations will not be run if this is set to False. + worker: The worker to use to get the checkpoint. + actor_id: The actor ID of the actor to get the checkpoint for. Returns: - A function that executes the given actor method on the worker's stored - instance of the actor. The function also updates the worker's - internal state to record the executed method. - """ - - def actor_method_executor(dummy_return_id, actor, *args): - # Update the actor's task counter to reflect the task we're about to - # execute. - worker.actor_task_counter += 1 - - # If this is the first task to execute on the actor, try to resume from - # a checkpoint. - if actor_imported and worker.actor_task_counter == 1: - checkpoint_resumed = restore_and_log_checkpoint(worker, actor) - if checkpoint_resumed: - # NOTE(swang): Since we did not actually execute the __init__ - # method, this will put None as the return value. If the - # __init__ method is supposed to return multiple values, an - # exception will be logged. - return - - # Determine whether we should checkpoint the actor. - checkpointing_on = (actor_imported - and worker.actor_checkpoint_interval > 0) - # We should checkpoint the actor if user checkpointing is on, we've - # executed checkpoint_interval tasks since the last checkpoint, and the - # method we're about to execute is not a checkpoint. - save_checkpoint = ( - checkpointing_on and - (worker.actor_task_counter % worker.actor_checkpoint_interval == 0 - and method_name != "__ray_checkpoint__")) - - # Execute the assigned method and save a checkpoint if necessary. - try: - if is_classmethod(method): - method_returns = method(*args) - else: - method_returns = method(actor, *args) - except Exception: - # Save the checkpoint before allowing the method exception to be - # thrown. - if save_checkpoint: - save_and_log_checkpoint(worker, actor) - raise - else: - # Save the checkpoint before returning the method's return values. - if save_checkpoint: - save_and_log_checkpoint(worker, actor) - return method_returns - - return actor_method_executor - - -def fetch_and_register_actor(actor_class_key, worker): - """Import an actor. - - This will be called by the worker's import thread when the worker receives - the actor_class export, assuming that the worker is an actor for that - class. - - Args: - actor_class_key: The key in Redis to use to fetch the actor. - worker: The worker to use. - """ - actor_id_str = worker.actor_id - (driver_id, class_id, class_name, module, pickled_class, - checkpoint_interval, actor_method_names) = worker.redis_client.hmget( - actor_class_key, [ - "driver_id", "class_id", "class_name", "module", "class", - "checkpoint_interval", "actor_method_names" - ]) - - class_name = decode(class_name) - module = decode(module) - checkpoint_interval = int(checkpoint_interval) - actor_method_names = json.loads(decode(actor_method_names)) - - # Create a temporary actor with some temporary methods so that if the actor - # fails to be unpickled, the temporary actor can be used (just to produce - # error messages and to prevent the driver from hanging). 
- class TemporaryActor(object): - pass - - worker.actors[actor_id_str] = TemporaryActor() - worker.actor_checkpoint_interval = checkpoint_interval - - def temporary_actor_method(*xs): - raise Exception("The actor with name {} failed to be imported, and so " - "cannot execute this method".format(class_name)) - - # Register the actor method executors. - for actor_method_name in actor_method_names: - function_id = compute_actor_method_function_id(class_name, - actor_method_name).id() - temporary_executor = make_actor_method_executor( - worker, - actor_method_name, - temporary_actor_method, - actor_imported=False) - worker.function_execution_info[driver_id][function_id] = ( - ray.worker.FunctionExecutionInfo( - function=temporary_executor, - function_name=actor_method_name, - max_calls=0)) - worker.num_task_executions[driver_id][function_id] = 0 - - try: - unpickled_class = pickle.loads(pickled_class) - worker.actor_class = unpickled_class - except Exception: - # If an exception was thrown when the actor was imported, we record the - # traceback and notify the scheduler of the failure. - traceback_str = ray.utils.format_error_message(traceback.format_exc()) - # Log the error message. - push_error_to_driver( - worker, - ray_constants.REGISTER_ACTOR_PUSH_ERROR, - traceback_str, - driver_id, - data={"actor_id": actor_id_str}) - # TODO(rkn): In the future, it might make sense to have the worker exit - # here. However, currently that would lead to hanging if someone calls - # ray.get on a method invoked on the actor. - else: - # TODO(pcm): Why is the below line necessary? - unpickled_class.__module__ = module - worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class) - - def pred(x): - return (inspect.isfunction(x) or inspect.ismethod(x) - or is_cython(x)) - - actor_methods = inspect.getmembers(unpickled_class, predicate=pred) - for actor_method_name, actor_method in actor_methods: - function_id = compute_actor_method_function_id( - class_name, actor_method_name).id() - executor = make_actor_method_executor( - worker, actor_method_name, actor_method, actor_imported=True) - worker.function_execution_info[driver_id][function_id] = ( - ray.worker.FunctionExecutionInfo( - function=executor, - function_name=actor_method_name, - max_calls=0)) - # We do not set worker.function_properties[driver_id][function_id] - # because we currently do need the actor worker to submit new tasks - # for the actor. - - -def publish_actor_class_to_key(key, actor_class_info, worker): - """Push an actor class definition to Redis. - - The is factored out as a separate function because it is also called - on cached actor class definitions when a worker connects for the first - time. - - Args: - key: The key to store the actor class info at. - actor_class_info: Information about the actor class. - worker: The worker to use to connect to Redis. + If a checkpoint exists, this returns a tuple of the number of tasks + included in the checkpoint, the saved checkpoint state, and the + task frontier at the time of the checkpoint. If no checkpoint + exists, all objects are set to None. The checkpoint index is the . + executed on the actor before the checkpoint was made. """ - # We set the driver ID here because it may not have been available when the - # actor class was defined. 
- actor_class_info["driver_id"] = worker.task_driver_id.id() - worker.redis_client.hmset(key, actor_class_info) - worker.redis_client.rpush("Exports", key) - - -def export_actor_class(class_id, Class, actor_method_names, - checkpoint_interval, worker): - key = b"ActorClass:" + class_id - actor_class_info = { - "class_name": Class.__name__, - "module": Class.__module__, - "class": pickle.dumps(Class), - "checkpoint_interval": checkpoint_interval, - "actor_method_names": json.dumps(list(actor_method_names)) - } - - check_oversized_pickle(actor_class_info["class"], - actor_class_info["class_name"], "actor", worker) - - if worker.mode is None: - # This means that 'ray.init()' has not been called yet and so we must - # cache the actor class definition and export it when 'ray.init()' is - # called. - assert worker.cached_remote_functions_and_actors is not None - worker.cached_remote_functions_and_actors.append( - ("actor", (key, actor_class_info))) - # This caching code path is currently not used because we only export - # actor class definitions lazily when we instantiate the actor for the - # first time. - assert False, "This should be unreachable." - else: - publish_actor_class_to_key(key, actor_class_info, worker) - # TODO(rkn): Currently we allow actor classes to be defined within tasks. - # I tried to disable this, but it may be necessary because of - # https://github.com/ray-project/ray/issues/1146. + actor_key = b"Actor:" + actor_id + checkpoint_index, checkpoint, frontier = worker.redis_client.hmget( + actor_key, ["checkpoint_index", "checkpoint", "frontier"]) + if checkpoint_index is not None: + checkpoint_index = int(checkpoint_index) + return checkpoint_index, checkpoint, frontier def method(*args, **kwargs): @@ -466,9 +225,15 @@ def __call__(self, *args, **kwargs): self._method_name)) def remote(self, *args, **kwargs): - return self._submit(args, kwargs) + return self._remote(args, kwargs) def _submit(self, args, kwargs, num_return_vals=None): + logger.warn( + "WARNING: _submit() is being deprecated. Please use _remote().") + return self._remote( + args=args, kwargs=kwargs, num_return_vals=num_return_vals) + + def _remote(self, args, kwargs, num_return_vals=None): if num_return_vals is None: num_return_vals = self._num_return_vals @@ -518,13 +283,8 @@ def __init__(self, modified_class, class_id, checkpoint_interval, num_cpus, self._actor_method_cpus = actor_method_cpus self._exported = False - # Get the actor methods of the given class. - def pred(x): - return (inspect.isfunction(x) or inspect.ismethod(x) - or is_cython(x)) - self._actor_methods = inspect.getmembers( - self._modified_class, predicate=pred) + self._modified_class, ray.utils.is_function_or_method) # Extract the signatures of each of the methods. This will be used # to catch some errors if the methods are called with inappropriate # arguments. @@ -537,7 +297,7 @@ def pred(x): # don't support, there may not be much the user can do about it. signature.check_signature_supported(method, warn=True) self._method_signatures[method_name] = signature.extract_signature( - method, ignore_first=not is_classmethod(method)) + method, ignore_first=not ray.utils.is_class_method(method)) # Set the default number of return values for this method. if hasattr(method, "__ray_num_return_vals__"): @@ -568,7 +328,7 @@ def remote(self, *args, **kwargs): Returns: A handle to the newly created actor. 
""" - return self._submit(args=args, kwargs=kwargs) + return self._remote(args=args, kwargs=kwargs) def _submit(self, args, @@ -576,6 +336,21 @@ def _submit(self, num_cpus=None, num_gpus=None, resources=None): + logger.warn( + "WARNING: _submit() is being deprecated. Please use _remote().") + return self._remote( + args=args, + kwargs=kwargs, + num_cpus=num_cpus, + num_gpus=num_gpus, + resources=resources) + + def _remote(self, + args, + kwargs, + num_cpus=None, + num_gpus=None, + resources=None): """Create an actor. This method allows more flexibility than the remote method because @@ -614,15 +389,24 @@ def _submit(self, else: # Export the actor. if not self._exported: - export_actor_class(self._class_id, self._modified_class, - self._actor_method_names, - self._checkpoint_interval, worker) + worker.function_actor_manager.export_actor_class( + self._class_id, self._modified_class, + self._actor_method_names, self._checkpoint_interval) self._exported = True resources = ray.utils.resources_from_resource_arguments( self._num_cpus, self._num_gpus, self._resources, num_cpus, num_gpus, resources) + # If the actor methods require CPU resources, then set the required + # placement resources. If actor_placement_resources is empty, then + # the required placement resources will be the same as resources. + actor_placement_resources = {} + assert self._actor_method_cpus in [0, 1] + if self._actor_method_cpus == 1: + actor_placement_resources = resources.copy() + actor_placement_resources["CPU"] += 1 + creation_args = [self._class_id] function_id = compute_actor_creation_function_id(self._class_id) [actor_cursor] = worker.submit_task( @@ -630,7 +414,8 @@ def _submit(self, creation_args, actor_creation_id=actor_id, num_return_vals=1, - resources=resources) + resources=resources, + placement_resources=actor_placement_resources) # We initialize the actor counter at 1 to account for the actor # creation task. @@ -741,6 +526,7 @@ def __init__(self, self._ray_actor_method_cpus = actor_method_cpus self._ray_actor_driver_id = actor_driver_id self._ray_previous_actor_handle_id = previous_actor_handle_id + self._ray_previously_generated_actor_handle_id = None def _actor_method_call(self, method_name, @@ -794,15 +580,27 @@ def _actor_method_call(self, is_actor_checkpoint_method = (method_name == "__ray_checkpoint__") + # Right now, if the actor handle has been pickled, we create a + # temporary actor handle id for invocations. + # TODO(pcm): This still leads to a lot of actor handles being + # created, there should be a better way to handle pickled + # actor handles. if self._ray_actor_handle_id is None: actor_handle_id = compute_actor_handle_id_non_forked( self._ray_actor_id, self._ray_previous_actor_handle_id, worker.current_task_id) + # Each new task creates a new actor handle id, so we need to + # reset the actor counter to 0 + if (actor_handle_id != + self._ray_previously_generated_actor_handle_id): + self._ray_actor_counter = 0 + self._ray_previously_generated_actor_handle_id = ( + actor_handle_id) else: actor_handle_id = self._ray_actor_handle_id - function_id = compute_actor_method_function_id(self._ray_class_name, - method_name) + function_id = FunctionActorManager.compute_actor_method_function_id( + self._ray_class_name, method_name) object_ids = worker.submit_task( function_id, args, @@ -816,6 +614,7 @@ def _actor_method_call(self, # We add one for the dummy return ID. 
num_return_vals=num_return_vals + 1, resources={"CPU": self._ray_actor_method_cpus}, + placement_resources={}, driver_id=self._ray_actor_driver_id) # Update the actor counter and cursor to reflect the most recent # invocation. @@ -983,8 +782,8 @@ def __ray_terminate__(self): # this is so that when the worker kills itself below, the local # scheduler won't push an error message to the driver. worker.local_scheduler_client.disconnect() - import os - os._exit(0) + sys.exit(0) + assert False, "This process should have terminated." def __ray_save_checkpoint__(self): if hasattr(self, "__ray_save__"): @@ -1068,5 +867,4 @@ def __ray_checkpoint_restore__(self): resources, actor_method_cpus) -ray.worker.global_worker.fetch_and_register_actor = fetch_and_register_actor ray.worker.global_worker.make_actor = make_actor diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index 2a3734ba403c3..9c4a452ee2687 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -69,6 +69,7 @@ "project_id": (None, OPTIONAL), # gcp project id, if using gcp "head_ip": (str, OPTIONAL), # local cluster head node "worker_ips": (list, OPTIONAL), # local cluster worker nodes + "use_internal_ips": (bool, OPTIONAL), # don't require public ips }, REQUIRED), @@ -490,8 +491,10 @@ def files_up_to_date(self, node_id): def recover_if_needed(self, node_id): if not self.can_update(node_id): return - last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip.get( - self.provider.internal_ip(node_id), 0) + key = self.provider.internal_ip(node_id) + if key not in self.load_metrics.last_heartbeat_time_by_ip: + self.load_metrics.last_heartbeat_time_by_ip[key] = time.time() + last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[key] delta = time.time() - last_heartbeat_time if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S: return diff --git a/python/ray/autoscaler/aws/config.py b/python/ray/autoscaler/aws/config.py index 8e5d3a4daffc5..62e0b25ee2e2d 100644 --- a/python/ray/autoscaler/aws/config.py +++ b/python/ray/autoscaler/aws/config.py @@ -10,6 +10,7 @@ import boto3 from botocore.config import Config +import botocore from ray.ray_constants import BOTO_MAX_RETRIES @@ -114,7 +115,8 @@ def _configure_key_pair(config): ec2 = _resource("ec2", config) # Try a few times to get or create a good key pair. - for i in range(10): + MAX_NUM_KEYS = 20 + for i in range(MAX_NUM_KEYS): key_name, key_path = key_pair(i, config["provider"]["region"]) key = _get_key(key_name, config) @@ -131,7 +133,12 @@ def _configure_key_pair(config): os.chmod(key_path, 0o600) break - assert key, "AWS keypair {} not found for {}".format(key_name, key_path) + if not key: + raise ValueError( + "No matching local key file for any of the key pairs in this " + "account with ids from 0..{}. 
".format(key_name) + + "Consider deleting some unused keys pairs from your account.") + assert os.path.exists(key_path), \ "Private key file {} not found for {}".format(key_path, key_name) @@ -146,9 +153,10 @@ def _configure_key_pair(config): def _configure_subnet(config): ec2 = _resource("ec2", config) + use_internal_ips = config["provider"].get("use_internal_ips", False) subnets = sorted( - (s for s in ec2.subnets.all() - if s.state == "available" and s.map_public_ip_on_launch), + (s for s in ec2.subnets.all() if s.state == "available" and ( + use_internal_ips or s.map_public_ip_on_launch)), reverse=True, # sort from Z-A key=lambda subnet: subnet.availability_zone) if not subnets: @@ -156,7 +164,8 @@ def _configure_subnet(config): "No usable subnets found, try manually creating an instance in " "your specified region to populate the list of subnets " "and trying this again. Note that the subnet must map public IPs " - "on instance launch.") + "on instance launch unless you set 'use_internal_ips': True in " + "the 'provider' config.") if "availability_zone" in config["provider"]: azs = config["provider"]["availability_zone"].split(',') subnets = [s for s in subnets if s.availability_zone in azs] @@ -264,7 +273,7 @@ def _get_role(role_name, config): try: role.load() return role - except Exception: + except botocore.errorfactory.NoSuchEntityException: return None @@ -274,7 +283,7 @@ def _get_instance_profile(profile_name, config): try: profile.load() return profile - except Exception: + except botocore.errorfactory.NoSuchEntityException: return None diff --git a/python/ray/autoscaler/aws/example-full.yaml b/python/ray/autoscaler/aws/example-full.yaml index 55691863fffb1..d74d45823c211 100644 --- a/python/ray/autoscaler/aws/example-full.yaml +++ b/python/ray/autoscaler/aws/example-full.yaml @@ -89,9 +89,9 @@ setup_commands: # has your Ray repo pre-cloned. Then, you can replace the pip installs # below with a git checkout (and possibly a recompile). 
- echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp35-cp35m-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp35-cp35m-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp36-cp36m-manylinux1_x86_64.whl # Consider uncommenting these if you also want to run apt-get commands during setup # - sudo pkill -9 apt-get || true # - sudo pkill -9 dpkg || true diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index 91d3e4b6b0713..9c63725296df9 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -189,7 +189,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, logger.info("Head node up-to-date, IP address is: {}".format( provider.external_ip(head_node))) - monitor_str = "tail -n 100 -f /tmp/raylogs/monitor-*" + monitor_str = "tail -n 100 -f /tmp/ray/session_*/logs/monitor*" for s in init_commands: if ("ray start" in s and "docker exec" in s and "--autoscaling-config" in s): diff --git a/python/ray/autoscaler/gcp/example-full.yaml b/python/ray/autoscaler/gcp/example-full.yaml index e9a95e8543be4..6afbb464fa6a0 100644 --- a/python/ray/autoscaler/gcp/example-full.yaml +++ b/python/ray/autoscaler/gcp/example-full.yaml @@ -124,9 +124,9 @@ setup_commands: pip install google-api-python-client==1.6.7 cython==0.27.3 - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp35-cp35m-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp35-cp35m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp36-cp36m-manylinux1_x86_64.whl - >- cd ~ && git clone https://github.com/ray-project/ray || true diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index b132971d2fc53..1d6b5e23b3840 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -47,7 +47,8 @@ def __init__(self, self.daemon = True self.process_runner = process_runner self.node_id = node_id - self.use_internal_ip = use_internal_ip + self.use_internal_ip = (use_internal_ip or provider_config.get( + "use_internal_ips", False)) self.provider = get_node_provider(provider_config, cluster_name) self.ssh_private_key = auth_config["ssh_private_key"] self.ssh_user = auth_config["ssh_user"] diff --git a/python/ray/common/redis_module/.gitkeep b/python/ray/common/redis_module/.gitkeep deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/python/ray/common/redis_module/runtest.py b/python/ray/common/redis_module/runtest.py deleted file mode 
100644 index 7a7d25c6bedc0..0000000000000 --- a/python/ray/common/redis_module/runtest.py +++ /dev/null @@ -1,451 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import redis -import sys -import time -import unittest - -import ray.gcs_utils -import ray.services - - -def integerToAsciiHex(num, numbytes): - retstr = b"" - # Support 32 and 64 bit architecture. - assert (numbytes == 4 or numbytes == 8) - for i in range(numbytes): - curbyte = num & 0xff - if sys.version_info >= (3, 0): - retstr += bytes([curbyte]) - else: - retstr += chr(curbyte) - num = num >> 8 - - return retstr - - -def get_next_message(pubsub_client, timeout_seconds=10): - """Block until the next message is available on the pubsub channel.""" - start_time = time.time() - while True: - message = pubsub_client.get_message() - if message is not None: - return message - time.sleep(0.1) - if time.time() - start_time > timeout_seconds: - raise Exception("Timed out while waiting for next message.") - - -class TestGlobalStateStore(unittest.TestCase): - def setUp(self): - unused_primary_redis_addr, redis_shards = ray.services.start_redis( - "localhost", use_credis="RAY_USE_NEW_GCS" in os.environ) - self.redis = redis.StrictRedis( - host="localhost", port=redis_shards[0].split(":")[-1], db=0) - - def tearDown(self): - ray.services.cleanup() - - def testInvalidObjectTableAdd(self): - # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called - # with the wrong arguments. - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", - "one", "hash2", "manager_id1") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1, - "hash2", "manager_id1", - "extra argument") - # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an - # object ID that is already present with a different hash. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1"}) - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id2") - # Check that the second manager was added, even though the hash was - # mismatched. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Check that it is fine if we add the same object ID multiple times - # with the most recent hash. 
- self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id2") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, - "hash2", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - - def testObjectTableAddAndLookup(self): - # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not - # been added yet. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(response, None) - # Add some managers and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Add a manager that already exists again and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Check that we properly handle NULL characters. In the past, NULL - # characters were handled improperly causing a "hash mismatch" error if - # two object IDs that agreed up to the NULL character were inserted - # with different hashes. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, - "hash2", "manager_id1") - # Check that NULL characters in the hash are handled properly. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, - "\x00hash1", "manager_id1") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, - "\x00hash2", "manager_id1") - - def testObjectTableAddAndRemove(self): - # Try removing a manager from an object ID that has not been added yet. - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") - # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not - # been added yet. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(response, None) - # Add some managers and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Remove a manager that doesn't exist, and make sure we still have the - # same set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id3") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Remove a manager that does exist. Make sure it gets removed the first - # time and does nothing the second time. 
- self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id2"}) - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id2"}) - # Remove the last manager, and make sure we have an empty set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), set()) - # Remove a manager from an empty set, and make sure we now have an - # empty set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id3") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), set()) - - def testObjectTableSubscribeToNotifications(self): - # Define a helper method for checking the contents of object - # notifications. - def check_object_notification(notification_message, object_id, - object_size, manager_ids): - notification_object = (ray.gcs_utils.SubscribeToNotificationsReply. - GetRootAsSubscribeToNotificationsReply( - notification_message, 0)) - self.assertEqual(notification_object.ObjectId(), object_id) - self.assertEqual(notification_object.ObjectSize(), object_size) - self.assertEqual(notification_object.ManagerIdsLength(), - len(manager_ids)) - for i in range(len(manager_ids)): - self.assertEqual( - notification_object.ManagerIds(i), manager_ids[i]) - - data_size = 0xf1f0 - p = self.redis.pubsub() - # Subscribe to an object ID. - p.psubscribe("{}manager_id1".format( - ray.gcs_utils.OBJECT_CHANNEL_PREFIX)) - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", - data_size, "hash1", "manager_id2") - # Receive the acknowledgement message. - self.assertEqual(get_next_message(p)["data"], 1) - # Request a notification and receive the data. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id1") - # Verify that the notification is correct. - check_object_notification( - get_next_message(p)["data"], b"object_id1", data_size, - [b"manager_id2"]) - - # Request a notification for an object that isn't there. Then add the - # object and receive the data. Only the first call to - # RAY.OBJECT_TABLE_ADD should trigger notifications. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id2", "object_id3") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", - data_size, "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", - data_size, "hash1", "manager_id2") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", - data_size, "hash1", "manager_id3") - # Verify that the notification is correct. - check_object_notification( - get_next_message(p)["data"], b"object_id3", data_size, - [b"manager_id1"]) - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", - data_size, "hash1", "manager_id3") - # Verify that the notification is correct. - check_object_notification( - get_next_message(p)["data"], b"object_id2", data_size, - [b"manager_id3"]) - # Request notifications for object_id3 again. 
- self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id3") - # Verify that the notification is correct. - check_object_notification( - get_next_message(p)["data"], b"object_id3", data_size, - [b"manager_id1", b"manager_id2", b"manager_id3"]) - - def testResultTableAddAndLookup(self): - def check_result_table_entry(message, task_id, is_put): - result_table_reply = ( - ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply( - message, 0)) - self.assertEqual(result_table_reply.TaskId(), task_id) - self.assertEqual(result_table_reply.IsPut(), is_put) - - # Try looking up something in the result table before anything is - # added. - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - self.assertIsNone(response) - # Adding the object to the object table should have no effect. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - self.assertIsNone(response) - # Add the result to the result table. The lookup now returns the task - # ID. - task_id = b"task_id1" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", - task_id, 0) - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - check_result_table_entry(response, task_id, False) - # Doing it again should still work. - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - check_result_table_entry(response, task_id, False) - # Try another result table lookup. This should succeed. - task_id = b"task_id2" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", - task_id, 1) - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id2") - check_result_table_entry(response, task_id, True) - - def testInvalidTaskTableAdd(self): - # Check that Redis returns an error when RAY.TASK_TABLE_ADD is called - # with the wrong arguments. - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.TASK_TABLE_ADD") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.TASK_TABLE_ADD", "hello") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 3, - "node_id") - with self.assertRaises(redis.ResponseError): - # Non-integer scheduling states should not be added. - self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", - "invalid_state", "node_id", "task_spec") - with self.assertRaises(redis.ResponseError): - # Should not be able to update a non-existent task. 
- self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10, - "node_id", b"") - - def testTaskTableAddAndLookup(self): - TASK_STATUS_WAITING = 1 - TASK_STATUS_SCHEDULED = 2 - TASK_STATUS_QUEUED = 4 - - # make sure somebody will get a notification (checked in the redis - # module) - p = self.redis.pubsub() - p.psubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX)) - - def check_task_reply(message, task_args, updated=False): - (task_status, local_scheduler_id, execution_dependencies_string, - spillback_count, task_spec) = task_args - task_reply_object = ray.gcs_utils.TaskReply.GetRootAsTaskReply( - message, 0) - self.assertEqual(task_reply_object.State(), task_status) - self.assertEqual(task_reply_object.LocalSchedulerId(), - local_scheduler_id) - self.assertEqual(task_reply_object.SpillbackCount(), - spillback_count) - self.assertEqual(task_reply_object.TaskSpec(), task_spec) - self.assertEqual(task_reply_object.Updated(), updated) - - # Check that task table adds, updates, and lookups work correctly. - task_args = [TASK_STATUS_WAITING, b"node_id", b"", 0, b"task_spec"] - response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", - *task_args) - response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - check_task_reply(response, task_args) - - task_args[0] = TASK_STATUS_SCHEDULED - self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", - *task_args[:4]) - response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - check_task_reply(response, task_args) - - # If the current value, test value, and set value are all the same, the - # update happens, and the response is still the same task. - task_args = [task_args[0]] + task_args - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *task_args[:3]) - check_task_reply(response, task_args[1:], updated=True) - # Check that the task entry is still the same. - get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") - check_task_reply(get_response, task_args[1:]) - - # If the current value is the same as the test value, and the set value - # is different, the update happens, and the response is the entire - # task. - task_args[1] = TASK_STATUS_QUEUED - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *task_args[:3]) - check_task_reply(response, task_args[1:], updated=True) - # Check that the update happened. - get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") - check_task_reply(get_response, task_args[1:]) - - # If the current value is no longer the same as the test value, the - # response is the same task as before the test-and-set. - new_task_args = task_args[:] - new_task_args[1] = TASK_STATUS_WAITING - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *new_task_args[:3]) - check_task_reply(response, task_args[1:], updated=False) - # Check that the update did not happen. - get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") - self.assertEqual(get_response2, get_response) - - # If the test value is a bitmask that matches the current value, the - # update happens. 
- task_args = new_task_args - task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *task_args[:3]) - check_task_reply(response, task_args[1:], updated=True) - - # If the test value is a bitmask that does not match the current value, - # the update does not happen, and the response is the same task as - # before the test-and-set. - new_task_args = task_args[:] - new_task_args[0] = TASK_STATUS_SCHEDULED - old_response = response - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *new_task_args[:3]) - check_task_reply(response, task_args[1:], updated=False) - # Check that the update did not happen. - get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") - self.assertNotEqual(get_response, old_response) - check_task_reply(get_response, task_args[1:]) - - def check_task_subscription(self, p, scheduling_state, local_scheduler_id): - task_args = [ - b"task_id", scheduling_state, - local_scheduler_id.encode("ascii"), b"", 0, b"task_spec" - ] - self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args) - # Receive the data. - message = get_next_message(p)["data"] - # Check that the notification object is correct. - notification_object = ray.gcs_utils.TaskReply.GetRootAsTaskReply( - message, 0) - self.assertEqual(notification_object.TaskId(), task_args[0]) - self.assertEqual(notification_object.State(), task_args[1]) - self.assertEqual(notification_object.LocalSchedulerId(), task_args[2]) - self.assertEqual(notification_object.ExecutionDependencies(), - task_args[3]) - self.assertEqual(notification_object.TaskSpec(), task_args[-1]) - - def testTaskTableSubscribe(self): - scheduling_state = 1 - local_scheduler_id = "local_scheduler_id" - # Subscribe to the task table. - p = self.redis.pubsub() - p.psubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 1) - self.check_task_subscription(p, scheduling_state, local_scheduler_id) - # unsubscribe to make sure there is only one subscriber at a given time - p.punsubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 0) - - p.psubscribe("{prefix}*:{state}".format( - prefix=ray.gcs_utils.TASK_PREFIX, state=scheduling_state)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 1) - self.check_task_subscription(p, scheduling_state, local_scheduler_id) - p.punsubscribe("{prefix}*:{state}".format( - prefix=ray.gcs_utils.TASK_PREFIX, state=scheduling_state)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 0) - - p.psubscribe("{prefix}{local_scheduler_id}:*".format( - prefix=ray.gcs_utils.TASK_PREFIX, - local_scheduler_id=local_scheduler_id)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 1) - self.check_task_subscription(p, scheduling_state, local_scheduler_id) - p.punsubscribe("{prefix}{local_scheduler_id}:*".format( - prefix=ray.gcs_utils.TASK_PREFIX, - local_scheduler_id=local_scheduler_id)) - # Receive acknowledgment. 
- self.assertEqual(get_next_message(p)["data"], 0) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/python/ray/common/test/test.py b/python/ray/common/test/test.py deleted file mode 100644 index cd36b697bbaad..0000000000000 --- a/python/ray/common/test/test.py +++ /dev/null @@ -1,181 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import pickle -import sys -import unittest - -import ray.local_scheduler as local_scheduler -import ray.ray_constants as ray_constants - - -def random_object_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_function_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_driver_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_task_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -BASE_SIMPLE_OBJECTS = [ - 0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, "", 990 * "h", u"", - 990 * u"h", - np.ones(3), - np.array([True, False]), None, True, False -] - -if sys.version_info < (3, 0): - BASE_SIMPLE_OBJECTS += [ - long(0), # noqa: E501,F821 - long(1), # noqa: E501,F821 - long(100000), # noqa: E501,F821 - long(1 << 100) # noqa: E501,F821 - ] - -LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS] -TUPLE_SIMPLE_OBJECTS = [(obj, ) for obj in BASE_SIMPLE_OBJECTS] -DICT_SIMPLE_OBJECTS = [{(): obj} for obj in BASE_SIMPLE_OBJECTS] - -SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS + LIST_SIMPLE_OBJECTS + - TUPLE_SIMPLE_OBJECTS + DICT_SIMPLE_OBJECTS) - -# Create some complex objects that cannot be serialized by value in tasks. - -lst = [] -lst.append(lst) - - -class Foo(object): - def __init__(self): - pass - - -BASE_COMPLEX_OBJECTS = [ - 15000 * "h", 15000 * u"h", lst, - Foo(), 100 * [100 * [10 * [1]]], - np.array([Foo()]) -] - -LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS] -TUPLE_COMPLEX_OBJECTS = [(obj, ) for obj in BASE_COMPLEX_OBJECTS] -DICT_COMPLEX_OBJECTS = [{(): obj} for obj in BASE_COMPLEX_OBJECTS] - -COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS + LIST_COMPLEX_OBJECTS + - TUPLE_COMPLEX_OBJECTS + DICT_COMPLEX_OBJECTS) - - -class TestSerialization(unittest.TestCase): - def test_serialize_by_value(self): - - for val in SIMPLE_OBJECTS: - self.assertTrue(local_scheduler.check_simple_value(val)) - for val in COMPLEX_OBJECTS: - self.assertFalse(local_scheduler.check_simple_value(val)) - - -class TestObjectID(unittest.TestCase): - def test_create_object_id(self): - random_object_id() - - def test_cannot_pickle_object_ids(self): - object_ids = [random_object_id() for _ in range(256)] - - def f(): - return object_ids - - def g(val=object_ids): - return 1 - - def h(): - object_ids[0] - return 1 - - # Make sure that object IDs cannot be pickled (including functions that - # close over object IDs). 
- self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0])) - self.assertRaises(Exception, lambda: pickle.dumps(object_ids)) - self.assertRaises(Exception, lambda: pickle.dumps(f)) - self.assertRaises(Exception, lambda: pickle.dumps(g)) - self.assertRaises(Exception, lambda: pickle.dumps(h)) - - def test_equality_comparisons(self): - x1 = local_scheduler.ObjectID(ray_constants.ID_SIZE * b"a") - x2 = local_scheduler.ObjectID(ray_constants.ID_SIZE * b"a") - y1 = local_scheduler.ObjectID(ray_constants.ID_SIZE * b"b") - y2 = local_scheduler.ObjectID(ray_constants.ID_SIZE * b"b") - self.assertEqual(x1, x2) - self.assertEqual(y1, y2) - self.assertNotEqual(x1, y1) - - random_strings = [ - np.random.bytes(ray_constants.ID_SIZE) for _ in range(256) - ] - object_ids1 = [ - local_scheduler.ObjectID(random_strings[i]) for i in range(256) - ] - object_ids2 = [ - local_scheduler.ObjectID(random_strings[i]) for i in range(256) - ] - self.assertEqual(len(set(object_ids1)), 256) - self.assertEqual(len(set(object_ids1 + object_ids2)), 256) - self.assertEqual(set(object_ids1), set(object_ids2)) - - def test_hashability(self): - x = random_object_id() - y = random_object_id() - {x: y} - {x, y} - - -class TestTask(unittest.TestCase): - def check_task(self, task, function_id, num_return_vals, args): - self.assertEqual(function_id.id(), task.function_id().id()) - retrieved_args = task.arguments() - self.assertEqual(num_return_vals, len(task.returns())) - self.assertEqual(len(args), len(retrieved_args)) - for i in range(len(retrieved_args)): - if isinstance(retrieved_args[i], local_scheduler.ObjectID): - self.assertEqual(retrieved_args[i].id(), args[i].id()) - else: - self.assertEqual(retrieved_args[i], args[i]) - - def test_create_and_serialize_task(self): - # TODO(rkn): The function ID should be a FunctionID object, not an - # ObjectID. 
- driver_id = random_driver_id() - parent_id = random_task_id() - function_id = random_function_id() - object_ids = [random_object_id() for _ in range(256)] - args_list = [[], 1 * [1], 10 * [1], 100 * [1], 1000 * [1], 1 * ["a"], - 10 * ["a"], 100 * ["a"], 1000 * ["a"], [ - 1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2] - ], object_ids[:1], object_ids[:2], object_ids[:3], - object_ids[:4], object_ids[:5], object_ids[:10], - object_ids[:100], object_ids[:256], [1, object_ids[0]], [ - object_ids[0], "a" - ], [1, object_ids[0], "a"], [ - object_ids[0], 1, object_ids[1], "a" - ], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], - object_ids + 100 * ["a"] + object_ids] - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(driver_id, function_id, args, - num_return_vals, parent_id, 0) - self.check_task(task, function_id, num_return_vals, args) - data = local_scheduler.task_to_string(task) - task2 = local_scheduler.task_from_string(data) - self.check_task(task2, function_id, num_return_vals, args) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/python/ray/common/thirdparty/redis/src/.gitkeep b/python/ray/common/thirdparty/redis/src/.gitkeep deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/python/ray/common/__init__.py b/python/ray/core/src/ray/__init__.py similarity index 100% rename from python/ray/common/__init__.py rename to python/ray/core/src/ray/__init__.py diff --git a/python/ray/core/src/local_scheduler/__init__.py b/python/ray/core/src/ray/raylet/__init__.py similarity index 100% rename from python/ray/core/src/local_scheduler/__init__.py rename to python/ray/core/src/ray/raylet/__init__.py diff --git a/python/ray/experimental/async_api.py b/python/ray/experimental/async_api.py new file mode 100644 index 0000000000000..8df8596e29aa5 --- /dev/null +++ b/python/ray/experimental/async_api.py @@ -0,0 +1,62 @@ +# Note: asyncio is only compatible with Python 3 + +import asyncio +import ray +from ray.experimental.async_plasma import PlasmaProtocol, PlasmaEventHandler + +handler = None +transport = None +protocol = None + + +async def _async_init(): + global handler, transport, protocol + if handler is None: + worker = ray.worker.global_worker + loop = asyncio.get_event_loop() + worker.plasma_client.subscribe() + rsock = worker.plasma_client.get_notification_socket() + handler = PlasmaEventHandler(loop, worker) + transport, protocol = await loop.create_connection( + lambda: PlasmaProtocol(worker.plasma_client, handler), sock=rsock) + + +def init(): + """ + Initialize synchronously. + """ + loop = asyncio.get_event_loop() + if loop.is_running(): + raise Exception("You must initialize the Ray async API by calling " + "async_api.init() or async_api.as_future(obj) before " + "the event loop starts.") + else: + asyncio.get_event_loop().run_until_complete(_async_init()) + + +def as_future(object_id): + """Turn an object_id into a Future object. + + Args: + object_id: A Ray object_id. + + Returns: + PlasmaObjectFuture: A future object that waits the object_id. + """ + if handler is None: + init() + return handler.as_future(object_id) + + +def shutdown(): + """Manually shutdown the async API. + + Cancels all related tasks and all the socket transportation. 
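# A minimal usage sketch for the async API above (not part of the patch
# itself). It assumes a local ray.init() and a trivial remote task f;
# asyncio, and therefore this API, is Python 3 only, and init() must run
# before the event loop starts.
import asyncio
import ray
from ray.experimental import async_api

ray.init()
async_api.init()  # set up the plasma notification connection up front

@ray.remote
def f():
    return 42

async def main():
    # as_future() wraps the Ray ObjectID in an asyncio-compatible future.
    result = await async_api.as_future(f.remote())
    print(result)  # 42

asyncio.get_event_loop().run_until_complete(main())
async_api.shutdown()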
+ """ + global handler, transport, protocol + if handler is not None: + handler.close() + transport.close() + handler = None + transport = None + protocol = None diff --git a/python/ray/experimental/async_plasma.py b/python/ray/experimental/async_plasma.py new file mode 100644 index 0000000000000..2c0f806f2467b --- /dev/null +++ b/python/ray/experimental/async_plasma.py @@ -0,0 +1,237 @@ +import asyncio +import ctypes +import sys + +import pyarrow.plasma as plasma + +import ray +from ray.services import logger + +INT64_SIZE = ctypes.sizeof(ctypes.c_int64) + + +def _release_waiter(waiter, *_): + if not waiter.done(): + waiter.set_result(None) + + +class PlasmaProtocol(asyncio.Protocol): + """Protocol control for the asyncio connection.""" + + def __init__(self, plasma_client, plasma_event_handler): + self.plasma_client = plasma_client + self.plasma_event_handler = plasma_event_handler + self.transport = None + self._buffer = b"" + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data): + self._buffer += data + messages = [] + i = 0 + while i + INT64_SIZE <= len(self._buffer): + msg_len = int.from_bytes(self._buffer[i:i + INT64_SIZE], + sys.byteorder) + if i + INT64_SIZE + msg_len > len(self._buffer): + break + i += INT64_SIZE + segment = self._buffer[i:i + msg_len] + i += msg_len + messages.append(self.plasma_client.decode_notification(segment)) + + self._buffer = self._buffer[i:] + self.plasma_event_handler.process_notifications(messages) + + def connection_lost(self, exc): + # The socket has been closed + logger.debug("PlasmaProtocol - connection lost.") + + def eof_received(self): + logger.debug("PlasmaProtocol - EOF received.") + self.transport.close() + + +class PlasmaObjectFuture(asyncio.Future): + """This class manages the lifecycle of a Future contains an object_id. + + Note: + This Future is an item in an linked list. + + Attributes: + object_id: The object_id this Future contains. + """ + + def __init__(self, loop, object_id): + super().__init__(loop=loop) + self.object_id = object_id + self.prev = None + self.next = None + + @property + def ray_object_id(self): + return ray.ObjectID(self.object_id.binary()) + + def __repr__(self): + return super().__repr__() + "{object_id=%s}" % self.object_id + + +class PlasmaObjectLinkedList(asyncio.Future): + """This class is a doubly-linked list. + It holds a ObjectID and maintains futures assigned to the ObjectID. + + Args: + loop: an event loop. + plain_object_id (plasma.ObjectID): + The plasma ObjectID this class holds. + """ + + def __init__(self, loop, plain_object_id): + super().__init__(loop=loop) + assert isinstance(plain_object_id, plasma.ObjectID) + self.object_id = plain_object_id + self.head = None + self.tail = None + + def append(self, future): + """Append an object to the linked list. + + Args: + future (PlasmaObjectFuture): A PlasmaObjectFuture instance. + """ + future.prev = self.tail + if self.tail is None: + assert self.head is None + self.head = future + else: + self.tail.next = future + self.tail = future + # Once done, it will be removed from the list. + future.add_done_callback(self.remove) + + def remove(self, future): + """Remove an object from the linked list. + + Args: + future (PlasmaObjectFuture): A PlasmaObjectFuture instance. 
+ """ + if self._loop.get_debug(): + logger.debug("Removing %s from the linked list.", future) + if future.prev is None: + assert future is self.head + self.head = future.next + if self.head is None: + self.tail = None + if not self.cancelled(): + self.set_result(None) + else: + self.head.prev = None + elif future.next is None: + assert future is self.tail + self.tail = future.prev + if self.tail is None: + self.head = None + if not self.cancelled(): + self.set_result(None) + else: + self.tail.prev = None + + def cancel(self, *args, **kwargs): + """Manually cancel all tasks assigned to this event loop.""" + # Because remove all futures will trigger `set_result`, + # we cancel itself first. + super().cancel() + for future in self.traverse(): + # All cancelled futures should have callbacks to removed itself + # from this linked list. However, these callbacks are scheduled in + # an event loop, so we could still find them in our list. + if not future.cancelled(): + future.cancel() + + def set_result(self, result): + """Complete all tasks. """ + for future in self.traverse(): + # All cancelled futures should have callbacks to removed itself + # from this linked list. However, these callbacks are scheduled in + # an event loop, so we could still find them in our list. + future.set_result(result) + if not self.done(): + super().set_result(result) + + def traverse(self): + """Traverse this linked list. + + Yields: + PlasmaObjectFuture: PlasmaObjectFuture instances. + """ + current = self.head + while current is not None: + yield current + current = current.next + + +class PlasmaEventHandler: + """This class is an event handler for Plasma.""" + + def __init__(self, loop, worker): + super().__init__() + self._loop = loop + self._worker = worker + self._waiting_dict = {} + + def process_notifications(self, messages): + """Process notifications.""" + for object_id, object_size, metadata_size in messages: + if object_size > 0 and object_id in self._waiting_dict: + linked_list = self._waiting_dict[object_id] + self._complete_future(linked_list) + + def close(self): + """Clean up this handler.""" + for linked_list in self._waiting_dict.values(): + linked_list.cancel() + # All cancelled linked lists should have callbacks to removed itself + # from the waiting dict. However, these callbacks are scheduled in + # an event loop, so we don't check them now. + + def _unregister_callback(self, fut): + del self._waiting_dict[fut.object_id] + + def _complete_future(self, fut): + obj = self._worker.retrieve_and_deserialize([fut.object_id], 0)[0] + fut.set_result(obj) + + def as_future(self, object_id, check_ready=True): + """Turn an object_id into a Future object. + + Args: + object_id: A Ray's object_id. + check_ready (bool): If true, check if the object_id is ready. + + Returns: + PlasmaObjectFuture: A future object that waits the object_id. 
+ """ + if not isinstance(object_id, ray.ObjectID): + raise TypeError("Input should be an ObjectID.") + + plain_object_id = plasma.ObjectID(object_id.id()) + fut = PlasmaObjectFuture(loop=self._loop, object_id=plain_object_id) + + if check_ready: + ready, _ = ray.wait([object_id], timeout=0) + if ready: + if self._loop.get_debug(): + logger.debug("%s has been ready.", plain_object_id) + self._complete_future(fut) + return fut + + if plain_object_id not in self._waiting_dict: + linked_list = PlasmaObjectLinkedList(self._loop, plain_object_id) + linked_list.add_done_callback(self._unregister_callback) + self._waiting_dict[plain_object_id] = linked_list + self._waiting_dict[plain_object_id].append(fut) + if self._loop.get_debug(): + logger.debug("%s added to the waiting list.", fut) + + return fut diff --git a/python/ray/experimental/sgd/__init__.py b/python/ray/experimental/sgd/__init__.py index e69de29bb2d1d..005b3fff0c1f0 100644 --- a/python/ray/experimental/sgd/__init__.py +++ b/python/ray/experimental/sgd/__init__.py @@ -0,0 +1,11 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.experimental.sgd.sgd import DistributedSGD +from ray.experimental.sgd.model import Model + +__all__ = [ + "DistributedSGD", + "Model", +] diff --git a/python/ray/experimental/sgd/mnist_example.py b/python/ray/experimental/sgd/mnist_example.py new file mode 100755 index 0000000000000..8c2fff213c94b --- /dev/null +++ b/python/ray/experimental/sgd/mnist_example.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python +"""Example of how to train a model with Ray SGD. + +We use a small model here, so no speedup for distributing the computation is +expected. This example shows: + - How to set up a simple input pipeline + - How to evaluate model accuracy during training + - How to get and set model weights + - How to train with ray.experimental.sgd.DistributedSGD +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import time + +from tensorflow.examples.tutorials.mnist import input_data +import tensorflow as tf + +import ray +from ray.tune import run_experiments +from ray.tune.examples.tune_mnist_ray import deepnn +from ray.experimental.sgd.model import Model +from ray.experimental.sgd.sgd import DistributedSGD +from ray.experimental.tfutils import TensorFlowVariables + +parser = argparse.ArgumentParser() +parser.add_argument("--redis-address", default=None, type=str) +parser.add_argument("--num-iters", default=10000, type=int) +parser.add_argument("--batch-size", default=50, type=int) +parser.add_argument("--num-workers", default=1, type=int) +parser.add_argument("--devices-per-worker", default=1, type=int) +parser.add_argument("--tune", action="store_true", help="Run in Ray Tune") +parser.add_argument( + "--strategy", default="ps", type=str, help="One of 'simple' or 'ps'") +parser.add_argument( + "--gpu", action="store_true", help="Use GPUs for optimization") + + +class MNISTModel(Model): + def __init__(self): + # Import data + error = None + for _ in range(10): + try: + self.mnist = input_data.read_data_sets( + "/tmp/tensorflow/mnist/input_data", one_hot=True) + error = None + break + except Exception as e: + error = e + time.sleep(5) + if error: + raise ValueError("Failed to import data", error) + + # Set seed and build layers + tf.set_random_seed(0) + self.x = tf.placeholder(tf.float32, [None, 784], name="x") + self.y_ = tf.placeholder(tf.float32, [None, 10], 
name="y_") + y_conv, self.keep_prob = deepnn(self.x) + + # Need to define loss and optimizer attributes + self.loss = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits( + labels=self.y_, logits=y_conv)) + self.optimizer = tf.train.AdamOptimizer(1e-4) + self.variables = TensorFlowVariables(self.loss, + tf.get_default_session()) + + # For evaluating test accuracy + correct_prediction = tf.equal( + tf.argmax(y_conv, 1), tf.argmax(self.y_, 1)) + self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + + def get_feed_dict(self): + batch = self.mnist.train.next_batch(50) + return { + self.x: batch[0], + self.y_: batch[1], + self.keep_prob: 0.5, + } + + def test_accuracy(self): + return self.accuracy.eval( + feed_dict={ + self.x: self.mnist.test.images, + self.y_: self.mnist.test.labels, + self.keep_prob: 1.0, + }) + + +def train_mnist(config, reporter): + args = config["args"] + sgd = DistributedSGD( + lambda w_i, d_i: MNISTModel(), + num_workers=args.num_workers, + devices_per_worker=args.devices_per_worker, + gpu=args.gpu, + strategy=args.strategy) + + # Important: synchronize the initial weights of all model replicas + w0 = sgd.for_model(lambda m: m.variables.get_flat()) + sgd.foreach_model(lambda m: m.variables.set_flat(w0)) + + for i in range(args.num_iters): + if i % 10 == 0: + start = time.time() + loss = sgd.step(fetch_stats=True)["loss"] + acc = sgd.foreach_model(lambda model: model.test_accuracy()) + print("Iter", i, "loss", loss, "accuracy", acc) + print("Time per iteration", time.time() - start) + assert len(set(acc)) == 1, ("Models out of sync", acc) + reporter(timesteps_total=i, mean_loss=loss, mean_accuracy=acc[0]) + else: + sgd.step() + + +if __name__ == "__main__": + args = parser.parse_args() + ray.init(redis_address=args.redis_address) + + if args.tune: + run_experiments({ + "mnist_sgd": { + "run": train_mnist, + "config": { + "args": args, + }, + }, + }) + else: + train_mnist({"args": args}, lambda **kw: None) diff --git a/python/ray/experimental/sgd/model.py b/python/ray/experimental/sgd/model.py new file mode 100644 index 0000000000000..ac8e0eedf23ea --- /dev/null +++ b/python/ray/experimental/sgd/model.py @@ -0,0 +1,26 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class Model(object): + """Your class must implement this interface to be used with Ray SGD. + + This supports any form of input pipeline: it is up to you to define it + using TensorFlow. The only requirements are that the loss and optimizer + attributes must be defined. + + For an example implementation, see tfbench/test_model.py + + Attributes: + loss (tf.Tensor): Loss function to minimize. + optimizer (tf.train.Optimizer): Optimizer to use to minimize the loss. + """ + + def get_feed_dict(self): + """Extra values to pass in when computing gradients for the loss. + + Returns: + TensorFlow feed_dict to add to the gradient operation. 
+ """ + return {} diff --git a/python/ray/experimental/sgd/modified_allreduce.py b/python/ray/experimental/sgd/modified_allreduce.py index a9d6879f99c7b..7c446aa974e15 100644 --- a/python/ray/experimental/sgd/modified_allreduce.py +++ b/python/ray/experimental/sgd/modified_allreduce.py @@ -584,7 +584,15 @@ def end_interval(indices, small_ranges, large_indices): if len(small_ranges): new_tower_grads = [] for dev_idx, gv_list in enumerate(tower_grads): - assert len(gv_list) == num_gv + assert len(gv_list) == num_gv, ( + "Possible cause: " + "Networks constructed on different workers " + "don't have the same number of variables. " + "If you use tf.GraphKeys or tf.global_variables() " + "with multiple graphs per worker during network " + "construction, you need to use " + "appropriate scopes, see " + "https://github.com/ray-project/ray/issues/3136") new_gv_list = [] for r in small_ranges: key = '%d:%d' % (dev_idx, len(new_gv_list)) diff --git a/python/ray/experimental/sgd/param_server.py b/python/ray/experimental/sgd/param_server.py new file mode 100644 index 0000000000000..517d419c36440 --- /dev/null +++ b/python/ray/experimental/sgd/param_server.py @@ -0,0 +1,82 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +import numpy as np + +import ray +from ray.experimental.sgd.util import Timeline, fetch, warmup + +logger = logging.getLogger(__name__) + + +@ray.remote(num_cpus=0) +class ParameterServer(object): + """Helper class for ray.experimental.sgd.DistributedSGD.""" + + def __init__(self, num_workers, tid): + self.num_sgd_workers = num_workers + self.acc_counter = 0 + self.timeline = Timeline(tid) + # TODO(ekl) get this to work again so we get ray events + # self.timeline.patch_ray() + + def initialize(self, shard_shape): + """Resets the gradient buffer to zeros.""" + self.accumulated = np.zeros(shard_shape, dtype=np.float32) + + def prefetch(self, oids): + """Tell plasma to prefetch the given object ids over the network.""" + self.timeline.reset() + self.timeline.start("prefetch") + fetch(oids) + self.timeline.end("prefetch") + + def add_spinwait(self, grad_shard_ids): + """Optimized version of add() that operates on multiple grads.""" + self.timeline.start("add_spinwait") + plasma_ids = [ray.pyarrow.plasma.ObjectID(x) for x in grad_shard_ids] + while plasma_ids: + for p in plasma_ids: + if ray.worker.global_worker.plasma_client.contains(p): + self.timeline.start("get_buffers") + grads = ray.worker.global_worker.plasma_client.get(p) + self.accumulated += grads + self.acc_counter += 1 + self.timeline.end("get_buffers") + plasma_ids.remove(p) + break + self.timeline.end("add_spinwait") + + def add(self, grad_shard_id): + """Add the given gradient value to the accumulated gradients.""" + self.timeline.start("add") + self.timeline.start("get_buffers") + oid = ray.pyarrow.plasma.ObjectID(grad_shard_id) + grads = ray.worker.global_worker.plasma_client.get(oid) + self.timeline.end("get_buffers") + self.accumulated += grads + self.acc_counter += 1 + self.timeline.end("add") + + def get(self, object_id): + """Put the accumulated gradients to the given object id.""" + self.timeline.start("get") + client = ray.worker.global_worker.plasma_client + assert self.acc_counter == self.num_sgd_workers, self.acc_counter + oid = ray.pyarrow.plasma.ObjectID(object_id) + client.put(self.accumulated.flatten(), object_id=oid) + self.accumulated = np.zeros_like(self.accumulated) + self.acc_counter = 0 + self.timeline.end("get") + + 
def get_timeline(self): + return self.timeline + + def ip(self): + return ray.services.get_node_ip_address() + + def warmup(self): + warmup() diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py index c569c036f1b10..a663960683f79 100644 --- a/python/ray/experimental/sgd/sgd.py +++ b/python/ray/experimental/sgd/sgd.py @@ -7,338 +7,187 @@ import time import numpy as np -import pyarrow.plasma as plasma -import tensorflow as tf import ray -from ray.experimental.sgd.util import Timeline, fetch, run_timeline -from ray.experimental.sgd.modified_allreduce import sum_gradients_all_reduce, \ - unpack_small_tensors +from ray.experimental.sgd.sgd_worker import SGDWorker +from ray.experimental.sgd.param_server import ParameterServer logger = logging.getLogger(__name__) -class SGDWorker(object): +class DistributedSGD(object): + """Experimental distributed SGD implementation in Ray. + + This supports two modes: + 'simple': centralized gradient aggregation + 'ps': sharded parameter-server implementation + + To use this class, you'll have to implement model.py:Model. + + Arguments: + model_creator (func): Function that returns a model given worker and + device indexes as arguments. Each model replica will be created + within its own variable scope. + num_workers (int): Number of Ray actors to use for SGD. + devices_per_worker (int): Number of GPU or CPU devices to use per + worker. One model replica will be created per device. + gpu (bool): Whether to use GPU devices. + strategy (str): Strategy to use for distributed gradient aggregation. + This only applies if num_workers > 1. + grad_shard_bytes (int): Fuse gradient tensors into chunks of at most + this size (if applicable). + all_reduce_alg (str): TensorFlow strategy to use for gradient + synchronization within the same worker (if applicable). + See modified_allreduce.py for options. + + Examples: + >>> # Setup distributed SGD + >>> model_creator = ( + ... lambda worker_idx, device_idx: YourModelClass(...)) + >>> sgd = DistributedSGD( + ... model_creator, num_workers=2, + ... devices_per_worker=4, gpu=True, strategy="ps") + + >>> # To train + >>> for i in range(100): + ... stats = sgd.step(fetch_stats=i % 10 == 0) + + >>> # To access or update model state + >>> sgd.foreach_model(lambda model: ...) + + >>> # To access or update worker state + >>> sgd.foreach_worker(lambda worker: ...) 
+ """ + def __init__(self, - worker_index, model_creator, - all_reduce_alg="simple", - num_devices=1, - use_cpus=False, - max_bytes=60000000, - plasma_op=False): - self.worker_index = worker_index - assert num_devices > 0 - - # TODO(ekl) support custom session - tf_session_args = { - "device_count": { - "CPU": num_devices - }, - "log_device_placement": False, - "gpu_options": tf.GPUOptions(force_gpu_compatible=True), - "inter_op_parallelism_threads": 128, - } - config_proto = tf.ConfigProto(**tf_session_args) - self.sess = tf.Session(config=config_proto) - self.models = [] - grad_ops = [] - - if use_cpus: - device_tmpl = "/cpu:%d" + num_workers, + devices_per_worker, + gpu=True, + strategy="ps", + grad_shard_bytes=10000000, + all_reduce_alg="simple"): + + if num_workers == 1 and strategy == "ps": + logger.warn( + "The parameter server strategy does not make sense for single " + "worker operation, falling back to simple mode.") + strategy = "simple" + + if strategy == "ps": + use_plasma_op = True + elif strategy == "simple": + use_plasma_op = False + grad_shard_bytes = 0 # tensor fusion doesn't make sense else: - device_tmpl = "/gpu:%d" - for device_idx in range(num_devices): - device = device_tmpl % device_idx - with tf.device(device): - with tf.variable_scope("device_%d" % device_idx): - model = model_creator(worker_index, device_idx) - self.models.append(model) - model.grads = [ - t - for t in model.optimizer.compute_gradients(model.loss) - if t[0] is not None - ] - grad_ops.append(model.grads) - - if num_devices == 1: - assert not max_bytes, "Not supported with 1 GPU" - self.packed_grads_and_vars = grad_ops + raise ValueError("strategy must be one of 'ps', 'simple'") + self.strategy = strategy + + self.model_creator = model_creator + if gpu: + requests = {"num_gpus": devices_per_worker} else: - if max_bytes: - self.packed_grads_and_vars, packing_vals = ( - sum_gradients_all_reduce( - "", - grad_ops, - 1, - all_reduce_alg, - 1, - list(range(num_devices)), - agg_small_grads_max_bytes=max_bytes)) - else: - self.packed_grads_and_vars, _ = (sum_gradients_all_reduce( - "", - grad_ops, - 1, - all_reduce_alg, - 1, - list(range(num_devices)), - agg_small_grads_max_bytes=0)) - self.per_device_grads = [ - list(zip(*dev_gv))[0] for dev_gv in self.packed_grads_and_vars - ] - assert (len(self.per_device_grads) == num_devices) - self.num_grads = num_grads = len(self.packed_grads_and_vars[0]) - if max_bytes: - logger.info("Packed grads => {} tensors".format(num_grads)) - - # Ops for reading grads with the right control deps - nccl_noops = [] - for j in range(num_grads)[::-1]: - deps = nccl_noops + [ - dev_grad[j] for dev_grad in self.per_device_grads - ] - with tf.control_dependencies(deps): - nccl_noops = [tf.no_op()] - - # You must fetch this otherwise the NCCL allreduce will hang - self.nccl_control_out = tf.group(*nccl_noops) - - round_robin_devices = False - if plasma_op: - store_socket = ( - ray.worker.global_worker.plasma_client.store_socket_name) - manager_socket = ( - ray.worker.global_worker.plasma_client.manager_socket_name) - if not plasma.tf_plasma_op: - plasma.build_plasma_tensorflow_op() - - # For fetching grads -> plasma - self.plasma_in_grads = [] - self.plasma_in_grads_oids = [ - tf.placeholder(shape=[], dtype=tf.string, name="in_grad_oids") - for _ in range(num_grads) - ] - ix = 0 - for j in range(num_grads): - grad = self.per_device_grads[ix][j] - if round_robin_devices: - ix += 1 # round robin assignment - ix %= num_devices - with tf.device(self.models[ix].loss.device): - 
plasma_grad = plasma.tf_plasma_op.tensor_to_plasma( - [grad], - self.plasma_in_grads_oids[j], - plasma_store_socket_name=store_socket, - plasma_manager_socket_name=manager_socket) - self.plasma_in_grads.append(plasma_grad) - - # For applying grads <- plasma - unpacked_gv = [] - self.plasma_out_grads_oids = [ - tf.placeholder( - shape=[], dtype=tf.string, name="grad_out_oids") - for _ in range(num_grads) + requests = {"num_cpus": devices_per_worker} + + RemoteSGDWorker = ray.remote(**requests)(SGDWorker) + self.workers = [] + logger.info( + "Creating SGD workers ({} total, {} devices per worker)".format( + num_workers, devices_per_worker)) + for worker_index in range(num_workers): + self.workers.append( + RemoteSGDWorker.remote( + worker_index, + model_creator, + num_devices=devices_per_worker, + plasma_op=use_plasma_op, + gpu=gpu, + max_bytes=grad_shard_bytes, + all_reduce_alg=all_reduce_alg)) + + logger.info("Waiting for gradient configuration") + shard_shapes = ray.get(self.workers[0].shard_shapes.remote()) + + logger.info("Waiting for actors to start") + ray.get([w.shard_shapes.remote() for w in self.workers]) + + if strategy == "ps": + logger.info("Starting parameter servers ({} shards)".format( + len(shard_shapes))) + self.ps_list = [ + ParameterServer.remote(len(self.workers), i) + for i, s in enumerate(shard_shapes) ] - packed_plasma_grads = [] - ix = 0 - for j in range(num_grads): - with tf.device(self.plasma_in_grads[j].device): - with tf.control_dependencies([self.plasma_in_grads[j]]): - grad_ph = plasma.tf_plasma_op.plasma_to_tensor( - self.plasma_out_grads_oids[j], - dtype=tf.float32, - plasma_store_socket_name=store_socket, - plasma_manager_socket_name=manager_socket) - grad_ph = tf.reshape(grad_ph, - self.packed_grads_and_vars[0][j][0].shape) - logger.debug("Packed tensor {}".format(grad_ph)) - packed_plasma_grads.append(grad_ph) - for i in range(num_devices): - per_device = [] - for j, (g, v) in enumerate(self.packed_grads_and_vars[i]): - grad_ph = packed_plasma_grads[j] - per_device.append((grad_ph, v)) - unpacked_gv.append(per_device) - - if max_bytes: - unpacked_gv = unpack_small_tensors(unpacked_gv, packing_vals) - - elif max_bytes: - unpacked_gv = unpack_small_tensors(self.packed_grads_and_vars, - packing_vals) + ray.get([ + ps.initialize.remote(s) + for ps, s in zip(self.ps_list, shard_shapes) + ]) + logger.info("Parameter servers started") else: - unpacked_gv = self.packed_grads_and_vars - - # Same shape as packed_grads_and_vars - assert len(unpacked_gv) == num_devices - assert len(unpacked_gv[0][0]) == 2 - - apply_ops = [] - to_apply = unpacked_gv[0] - for ix, m in enumerate(self.models): - apply_ops.append( - m.optimizer.apply_gradients( - [(g, v) - for ((g, _), (_, v)) in zip(to_apply, unpacked_gv[ix])])) - self.apply_op = tf.group(*apply_ops) - init_op = tf.group(tf.global_variables_initializer(), - tf.local_variables_initializer()) - self.sess.run(init_op) + self.ps_list = [] + + def foreach_worker(self, fn): + """Apply the given function to each remote worker. + + Returns: + List of results from applying the function. + """ + results = ray.get([w.foreach_worker.remote(fn) for w in self.workers]) + return results def foreach_model(self, fn): - return [fn(m) for m in self.models] + """Apply the given function to each model replica in each worker. - def foreach_worker(self, fn): - return fn(self) - - def compute_gradients(self): - start = time.time() - feed_dict = {} - # Aggregate feed dicts for each model on this worker. 
- for model in self.models: - feed_dict.update(model.get_feed_dict()) - # We only need to fetch the first per_device_grad, since they are - # averaged across all devices by allreduce. - fetches = self.sess.run( - [ - self.models[0].loss, self.per_device_grads[0], - self.nccl_control_out - ], - feed_dict=feed_dict) - logger.debug( - "compute grad interior time {}".format(time.time() - start)) - return fetches - - def apply_gradients(self, avg_grads): - start = time.time() - result = { - g: avg_grads[i] - for (i, g) in enumerate(self.per_device_grads[0]) - } - self.sess.run(self.apply_op, feed_dict=result) - logger.debug("apply grad interior time {}".format(time.time() - start)) - - def ps_compute_apply(self, - out_grad_shard_oids, - agg_grad_shard_oids, - tl_name="ps_compute_apply", - write_timeline=False): - feed_dict = { - ph: oid - for (ph, - oid) in zip(self.plasma_in_grads_oids, out_grad_shard_oids) - } - feed_dict.update({ - ph: oid - for (ph, - oid) in zip(self.plasma_out_grads_oids, agg_grad_shard_oids) - }) - fetch(agg_grad_shard_oids) - run_timeline( - self.sess, - [self.plasma_in_grads, self.apply_op, self.nccl_control_out], - feed_dict=feed_dict, - write_timeline=write_timeline) - - def num_grad_shards(self): - return self.num_grads - - def shard_shapes(self): - main_gv = self.packed_grads_and_vars[0] - return [g.shape for g, _ in main_gv] - - def ip(self): - return ray.services.get_node_ip_address() - - -class ParameterServer(object): - def __init__(self, num_workers, tid): - self.num_sgd_workers = num_workers - self.acc_counter = 0 - self.timeline = Timeline(tid) - self.timeline.patch_ray() - - def set_tid(self, tid): - self.timeline.tid = tid - - def get_time(self): - return time.time() + self.timeline.offset - - def set_time(self, ref_time): - self.timeline.offset = ref_time - time.time() - - def initialize(self, shard_shape): - self.accumulated = np.zeros(shard_shape, dtype=np.float32) - - def mark(self): - self.timeline.event("mark") - - def prefetch(self, oids): - self.timeline.reset() - self.timeline.start("prefetch") - fetch(oids) - self.timeline.end("prefetch") - - def add_spinwait(self, grad_shard_ids): - self.timeline.start("add_spinwait") - plasma_ids = [ray.pyarrow.plasma.ObjectID(x) for x in grad_shard_ids] - while plasma_ids: - for p in plasma_ids: - if ray.worker.global_worker.plasma_client.contains(p): - self.timeline.start("get_buffers") - grads = ray.worker.global_worker.plasma_client.get(p) - self.accumulated += grads - self.acc_counter += 1 - self.timeline.end("get_buffers") - plasma_ids.remove(p) - break - self.timeline.end("add_spinwait") - - def add(self, grad_shard_id): - self.timeline.start("add") - self.timeline.start("get_buffers") - oid = ray.pyarrow.plasma.ObjectID(grad_shard_id) - grads = ray.worker.global_worker.plasma_client.get(oid) - self.timeline.end("get_buffers") - self.accumulated += grads - self.acc_counter += 1 - self.timeline.end("add") - - def get(self, object_id): - self.timeline.start("get") - client = ray.worker.global_worker.plasma_client - assert self.acc_counter == self.num_sgd_workers, self.acc_counter - oid = ray.pyarrow.plasma.ObjectID(object_id) - client.put(self.accumulate.flatten(), object_id=oid) - self.accumulated = np.zeros_like(self.accumulated) - self.acc_counter = 0 - self.timeline.end("get") - - def get_timeline(self): - return self.timeline - - def ip(self): - return ray.services.get_node_ip_address() - - def pin(self, cpu_id): - try: - import psutil - p = psutil.Process() - p.cpu_affinity([cpu_id]) - 
logger.info("Setting CPU Affinity to: {}".format(cpu_id)) - except Exception as e: - logger.error(e) - - -def average_gradients(grads): + Returns: + List of results from applying the function. + """ + results = ray.get([w.foreach_model.remote(fn) for w in self.workers]) + out = [] + for r in results: + out.extend(r) + return out + + def for_model(self, fn): + """Apply the given function to a single model replica. + + Returns: + Result from applying the function. + """ + return ray.get(self.workers[0].for_model.remote(fn)) + + def step(self, fetch_stats=False): + """Run a single SGD step. + + Arguments: + fetch_stats (bool): Whether to return stats from the step. This can + slow down the computation by acting as a global barrier. + """ + if self.strategy == "ps": + return _distributed_sgd_step( + self.workers, + self.ps_list, + write_timeline=False, + fetch_stats=fetch_stats) + else: + return _simple_sgd_step(self.workers) + + def warmup(self): + logger.info("Warming up object store of worker actors") + ray.get([w.warmup.remote() for w in self.workers]) + logger.info("Warmup complete") + + +def _average_gradients(grads): out = [] for grad_list in zip(*grads): out.append(np.mean(grad_list, axis=0)) return out -def do_sgd_step(actors): +def _simple_sgd_step(actors): + if len(actors) == 1: + return {"loss": ray.get(actors[0].compute_apply.remote())} + start = time.time() fetches = ray.get([a.compute_gradients.remote() for a in actors]) losses = [f[0] for f in fetches] @@ -349,29 +198,33 @@ def do_sgd_step(actors): assert len(grads) == 1 avg_grad = grads[0] else: - avg_grad = average_gradients(grads) + avg_grad = _average_gradients(grads) logger.debug("grad reduce time {}".format(time.time() - start)) start = time.time() ray.get([a.apply_gradients.remote(avg_grad) for a in actors]) logger.debug("apply all grads time {}".format(time.time() - start)) - return np.mean(losses) + return {"loss": np.mean(losses)} -def distributed_sgd_step(actors, ps_list, write_timeline): +def _distributed_sgd_step(actors, ps_list, fetch_stats, write_timeline): # Preallocate object ids that actors will write gradient shards to grad_shard_oids_list = [[np.random.bytes(20) for _ in ps_list] for _ in actors] - logger.info("generated grad oids") + logger.debug("Generated grad oids") # Preallocate object ids that param servers will write new weights to accum_shard_ids = [np.random.bytes(20) for _ in ps_list] - logger.info("generated accum oids") + logger.debug("Generated accum oids") # Kick off the fused compute grad / update weights tf run for each actor + losses = [] for actor, grad_shard_oids in zip(actors, grad_shard_oids_list): - actor.ps_compute_apply.remote( - grad_shard_oids, accum_shard_ids, write_timeline=write_timeline) - logger.info("Launched all ps_compute_applys on all actors") + losses.append( + actor.ps_compute_apply.remote( + grad_shard_oids, + accum_shard_ids, + write_timeline=write_timeline)) + logger.debug("Launched all ps_compute_applys on all actors") # Issue prefetch ops for j, (ps, weight_shard_oid) in list( @@ -381,7 +234,7 @@ def distributed_sgd_step(actors, ps_list, write_timeline): to_fetch.append(grad_shard_oids[j]) random.shuffle(to_fetch) ps.prefetch.remote(to_fetch) - logger.info("Launched all prefetch ops") + logger.debug("Launched all prefetch ops") # Aggregate the gradients produced by the actors. These operations # run concurrently with the actor methods above. 
@@ -390,11 +243,11 @@ def distributed_sgd_step(actors, ps_list, write_timeline): enumerate(zip(ps_list, accum_shard_ids)))[::-1]: ps.add_spinwait.remote([gs[j] for gs in grad_shard_oids_list]) ps_gets.append(ps.get.remote(weight_shard_oid)) - logger.info("Launched all aggregate ops") + logger.debug("Launched all aggregate ops") if write_timeline: timelines = [ps.get_timeline.remote() for ps in ps_list] - logger.info("launched timeline gets") + logger.debug("Launched timeline gets") timelines = ray.get(timelines) t0 = timelines[0] for t in timelines[1:]: @@ -403,44 +256,7 @@ def distributed_sgd_step(actors, ps_list, write_timeline): else: # Wait for at least the ps gets to finish ray.get(ps_gets) - - -class DistributedSGD(object): - def __init__(self, - model_creator, - num_workers, - devices_per_worker, - use_cpus=False, - use_plasma_op=False): - self.model_creator = model_creator - if use_cpus: - requests = {"num_cpus": devices_per_worker} - else: - requests = {"num_gpus": devices_per_worker} - RemoteSGDWorker = ray.remote(**requests)(SGDWorker) - self.workers = [] - for worker_index in range(num_workers): - logger.info("Creating worker {}".format(worker_index)) - self.workers.append( - RemoteSGDWorker.remote( - worker_index, - model_creator, - num_devices=devices_per_worker, - plasma_op=use_plasma_op, - use_cpus=use_cpus)) - assert not use_plasma_op, \ - "TODO: when use_plasma_op is true, we must run in PS mode" - - def foreach_worker(self, fn): - results = ray.get([w.foreach_worker.remote(fn) for w in self.workers]) - return results - - def foreach_model(self, fn): - results = ray.get([w.foreach_model.remote(fn) for w in self.workers]) - out = [] - for r in results: - out.extend(r) - return r - - def step(self): - return do_sgd_step(self.workers) + if fetch_stats: + return {"loss": np.mean(ray.get(losses))} + else: + return None diff --git a/python/ray/experimental/sgd/sgd_worker.py b/python/ray/experimental/sgd/sgd_worker.py new file mode 100644 index 0000000000000..0d4b45c7c8bc4 --- /dev/null +++ b/python/ray/experimental/sgd/sgd_worker.py @@ -0,0 +1,268 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import time + +import pyarrow.plasma as plasma +import tensorflow as tf + +import ray +from ray.experimental.sgd.util import fetch, run_timeline, warmup +from ray.experimental.sgd.modified_allreduce import sum_gradients_all_reduce, \ + unpack_small_tensors + +logger = logging.getLogger(__name__) + + +class SGDWorker(object): + """Helper class for ray.experimental.sgd.DistributedSGD.""" + + def __init__(self, + worker_index, + model_creator, + all_reduce_alg="simple", + num_devices=1, + gpu=False, + max_bytes=10000000, + plasma_op=False): + self.worker_index = worker_index + assert num_devices > 0 + + # TODO(ekl) support custom session + tf_session_args = { + "device_count": { + "CPU": num_devices + }, + "log_device_placement": False, + "gpu_options": tf.GPUOptions(force_gpu_compatible=True), + "inter_op_parallelism_threads": 128, + } + config_proto = tf.ConfigProto(**tf_session_args) + self.sess = tf.Session(config=config_proto) + self.models = [] + grad_ops = [] + + if gpu: + device_tmpl = "/gpu:%d" + else: + device_tmpl = "/cpu:%d" + with self.sess.as_default(): + for device_idx in range(num_devices): + device = device_tmpl % device_idx + with tf.device(device): + with tf.variable_scope("device_%d" % device_idx): + model = model_creator(worker_index, device_idx) + self.models.append(model) + grads = [ 
+ t for t in model.optimizer.compute_gradients( + model.loss) if t[0] is not None + ] + grad_ops.append(grads) + + if num_devices == 1: + if max_bytes: + raise ValueError( + "Implementation limitation: grad_shard_bytes > 0 " + "({}) currently requires > 1 device".format(max_bytes)) + self.packed_grads_and_vars = grad_ops + else: + if max_bytes: + self.packed_grads_and_vars, packing_vals = ( + sum_gradients_all_reduce( + "", + grad_ops, + 1, + all_reduce_alg, + 1, + list(range(num_devices)), + agg_small_grads_max_bytes=max_bytes)) + else: + self.packed_grads_and_vars, _ = (sum_gradients_all_reduce( + "", + grad_ops, + 1, + all_reduce_alg, + 1, + list(range(num_devices)), + agg_small_grads_max_bytes=0)) + self.per_device_grads = [ + list(zip(*dev_gv))[0] for dev_gv in self.packed_grads_and_vars + ] + assert (len(self.per_device_grads) == num_devices) + self.num_grads = num_grads = len(self.packed_grads_and_vars[0]) + if max_bytes: + logger.info("Packed grads => {} tensors".format(num_grads)) + + # Ops for reading grads with the right control deps + nccl_noops = [] + for j in range(num_grads)[::-1]: + deps = nccl_noops + [ + dev_grad[j] for dev_grad in self.per_device_grads + ] + with tf.control_dependencies(deps): + nccl_noops = [tf.no_op()] + + # You must fetch this otherwise the NCCL allreduce will hang + self.nccl_control_out = tf.group(*nccl_noops) + + if plasma_op: + store_socket = ( + ray.worker.global_worker.plasma_client.store_socket_name) + manager_socket = ( + ray.worker.global_worker.plasma_client.manager_socket_name) + if not plasma.tf_plasma_op: + plasma.build_plasma_tensorflow_op() + + # For fetching grads -> plasma + self.plasma_in_grads = [] + self.plasma_in_grads_oids = [ + tf.placeholder(shape=[], dtype=tf.string, name="in_grad_oids") + for _ in range(num_grads) + ] + for j in range(num_grads): + grad = self.per_device_grads[0][j] + with tf.device(self.models[0].loss.device): + plasma_grad = plasma.tf_plasma_op.tensor_to_plasma( + [grad], + self.plasma_in_grads_oids[j], + plasma_store_socket_name=store_socket, + plasma_manager_socket_name=manager_socket) + self.plasma_in_grads.append(plasma_grad) + + # For applying grads <- plasma + unpacked_gv = [] + self.plasma_out_grads_oids = [ + tf.placeholder( + shape=[], dtype=tf.string, name="grad_out_oids") + for _ in range(num_grads) + ] + packed_plasma_grads = [] + for j in range(num_grads): + with tf.device(self.plasma_in_grads[j].device): + with tf.control_dependencies([self.plasma_in_grads[j]]): + grad_ph = plasma.tf_plasma_op.plasma_to_tensor( + self.plasma_out_grads_oids[j], + dtype=tf.float32, + plasma_store_socket_name=store_socket, + plasma_manager_socket_name=manager_socket) + grad_ph = tf.reshape(grad_ph, + self.packed_grads_and_vars[0][j][0].shape) + logger.debug("Packed tensor {}".format(grad_ph)) + packed_plasma_grads.append(grad_ph) + for i in range(num_devices): + per_device = [] + for j, (g, v) in enumerate(self.packed_grads_and_vars[i]): + grad_ph = packed_plasma_grads[j] + per_device.append((grad_ph, v)) + unpacked_gv.append(per_device) + + if max_bytes: + unpacked_gv = unpack_small_tensors(unpacked_gv, packing_vals) + + elif max_bytes: + unpacked_gv = unpack_small_tensors(self.packed_grads_and_vars, + packing_vals) + else: + unpacked_gv = self.packed_grads_and_vars + + # Same shape as packed_grads_and_vars + assert len(unpacked_gv) == num_devices + assert len(unpacked_gv[0][0]) == 2 + + apply_ops = [] + to_apply = unpacked_gv[0] + for ix, m in enumerate(self.models): + apply_ops.append( + 
m.optimizer.apply_gradients( + [(g, v) + for ((g, _), (_, v)) in zip(to_apply, unpacked_gv[ix])])) + self.apply_op = tf.group(*apply_ops) + init_op = tf.group(tf.global_variables_initializer(), + tf.local_variables_initializer()) + self.sess.run(init_op) + + def _grad_feed_dict(self): + # Aggregate feed dicts for each model on this worker. + feed_dict = {} + for model in self.models: + feed_dict.update(model.get_feed_dict()) + return feed_dict + + def foreach_model(self, fn): + with self.sess.as_default(): + return [fn(m) for m in self.models] + + def foreach_worker(self, fn): + with self.sess.as_default(): + return fn(self) + + def for_model(self, fn): + with self.sess.as_default(): + return fn(self.models[0]) + + def compute_gradients(self): + start = time.time() + feed_dict = self._grad_feed_dict() + # We only need to fetch the first per_device_grad, since they are + # averaged across all devices by allreduce. + fetches = self.sess.run( + [ + self.models[0].loss, self.per_device_grads[0], + self.nccl_control_out + ], + feed_dict=feed_dict) + logger.debug( + "Compute grad interior time {}".format(time.time() - start)) + return fetches + + def apply_gradients(self, avg_grads): + start = time.time() + result = { + g: avg_grads[i] + for (i, g) in enumerate(self.per_device_grads[0]) + } + self.sess.run(self.apply_op, feed_dict=result) + logger.debug("Apply grad interior time {}".format(time.time() - start)) + + def compute_apply(self): + fetches = run_timeline( + self.sess, + [self.models[0].loss, self.apply_op, self.nccl_control_out], + feed_dict=self._grad_feed_dict(), + name="compute_apply") + return fetches[0] + + def ps_compute_apply(self, + out_grad_shard_oids, + agg_grad_shard_oids, + tl_name="ps_compute_apply", + write_timeline=False): + feed_dict = self._grad_feed_dict() + feed_dict.update( + dict(zip(self.plasma_in_grads_oids, out_grad_shard_oids))) + feed_dict.update( + dict(zip(self.plasma_out_grads_oids, agg_grad_shard_oids))) + fetch(agg_grad_shard_oids) + fetches = run_timeline( + self.sess, [ + self.models[0].loss, self.plasma_in_grads, self.apply_op, + self.nccl_control_out + ], + feed_dict=feed_dict, + write_timeline=write_timeline) + return fetches[0] + + def num_grad_shards(self): + return self.num_grads + + def shard_shapes(self): + main_gv = self.packed_grads_and_vars[0] + return [g.shape for g, _ in main_gv] + + def ip(self): + return ray.services.get_node_ip_address() + + def warmup(self): + warmup() diff --git a/python/ray/experimental/sgd/test_sgd.py b/python/ray/experimental/sgd/test_sgd.py index d6369a4e00011..79e00b2656ba7 100755 --- a/python/ray/experimental/sgd/test_sgd.py +++ b/python/ray/experimental/sgd/test_sgd.py @@ -1,32 +1,67 @@ +#!/usr/bin/env python + from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse +import time import ray from ray.experimental.sgd.tfbench.test_model import TFBenchModel from ray.experimental.sgd.sgd import DistributedSGD parser = argparse.ArgumentParser() +parser.add_argument("--redis-address", default=None, type=str) +parser.add_argument("--num-iters", default=10, type=int) +parser.add_argument("--batch-size", default=1, type=int) +parser.add_argument("--num-workers", default=2, type=int) +parser.add_argument("--grad-shard-bytes", default=10000000, type=int) +parser.add_argument("--devices-per-worker", default=2, type=int) +parser.add_argument("--stats-interval", default=10, type=int) +parser.add_argument("--all-reduce-alg", default="simple", type=str) 
+parser.add_argument("--object-store-memory", default=None, type=int) +parser.add_argument( + "--warmup", action="store_true", help="Warm up object store before start.") parser.add_argument( - "--num-iters", default=100, type=int, help="Number of iterations to run") + "--strategy", default="ps", type=str, help="One of 'simple' or 'ps'") +parser.add_argument( + "--gpu", action="store_true", help="Use GPUs for optimization") if __name__ == "__main__": - ray.init() - args, _ = parser.parse_known_args() + ray.init( + redis_address=args.redis_address, + object_store_memory=args.object_store_memory) model_creator = ( - lambda worker_idx, device_idx: TFBenchModel(batch=1, use_cpus=True)) + lambda worker_idx, device_idx: TFBenchModel( + batch=args.batch_size, use_cpus=not args.gpu)) sgd = DistributedSGD( model_creator, - num_workers=2, - devices_per_worker=2, - use_cpus=True, - use_plasma_op=False) - - for _ in range(args.num_iters): - loss = sgd.step() - print("Current loss", loss) + num_workers=args.num_workers, + devices_per_worker=args.devices_per_worker, + gpu=args.gpu, + strategy=args.strategy, + grad_shard_bytes=args.grad_shard_bytes, + all_reduce_alg=args.all_reduce_alg) + + if args.warmup: + sgd.warmup() + + t = [] + + for i in range(args.num_iters): + start = time.time() + fetch_stats = i % args.stats_interval == 0 + print("== Step {} ==".format(i)) + stats = sgd.step(fetch_stats=fetch_stats) + ips = ((args.batch_size * args.num_workers * args.devices_per_worker) / + (time.time() - start)) + print("Iteration time", time.time() - start, "Images per second", ips) + t.append(ips) + if fetch_stats: + print("Current loss", stats) + + print("Peak throughput", max(sum(t[i:i + 5]) / 5 for i in range(len(t)))) diff --git a/python/ray/experimental/sgd/tfbench/test_model.py b/python/ray/experimental/sgd/tfbench/test_model.py index 0dd48607ef0a6..d866668f810d5 100644 --- a/python/ray/experimental/sgd/tfbench/test_model.py +++ b/python/ray/experimental/sgd/tfbench/test_model.py @@ -5,13 +5,14 @@ import tensorflow as tf from tfbench import model_config +from ray.experimental.sgd.model import Model class MockDataset(): name = "synthetic" -class TFBenchModel(object): +class TFBenchModel(Model): def __init__(self, batch=64, use_cpus=False): image_shape = [batch, 224, 224, 3] labels_shape = [batch] @@ -25,20 +26,22 @@ def __init__(self, batch=64, use_cpus=False): name='synthetic_images') # Minor hack to avoid H2D copy when using synthetic data - self.inputs = tf.contrib.framework.local_variable( + inputs = tf.contrib.framework.local_variable( images, name='gpu_cached_images') - self.labels = tf.random_uniform( + labels = tf.random_uniform( labels_shape, minval=0, maxval=999, dtype=tf.int32, name='synthetic_labels') - self.model = model_config.get_model_config("resnet101", MockDataset()) - logits, aux = self.model.build_network( - self.inputs, data_format=use_cpus and "NHWC" or "NCHW") + model = model_config.get_model_config("resnet101", MockDataset()) + logits, aux = model.build_network( + inputs, data_format=use_cpus and "NHWC" or "NCHW") loss = tf.nn.sparse_softmax_cross_entropy_with_logits( - logits=logits, labels=self.labels) + logits=logits, labels=labels) + + # Implement model interface self.loss = tf.reduce_mean(loss, name='xentropy-loss') self.optimizer = tf.train.GradientDescentOptimizer(1e-6) diff --git a/python/ray/experimental/sgd/util.py b/python/ray/experimental/sgd/util.py index ca72bb5e9ef43..c8df01cb35b25 100644 --- a/python/ray/experimental/sgd/util.py +++ 
b/python/ray/experimental/sgd/util.py @@ -4,6 +4,7 @@ import json import logging +import numpy as np import os import time import tensorflow as tf @@ -13,16 +14,29 @@ logger = logging.getLogger(__name__) +def warmup(): + logger.info("Warming up object store") + zeros = np.zeros(int(100e6 / 8), dtype=np.float64) + start = time.time() + for _ in range(10): + ray.put(zeros) + logger.info("Initial latency for 100MB put {}".format( + (time.time() - start) / 10)) + for _ in range(5): + for _ in range(100): + ray.put(zeros) + start = time.time() + for _ in range(10): + ray.put(zeros) + logger.info("Warmed up latency for 100MB put {}".format( + (time.time() - start) / 10)) + + def fetch(oids): - if ray.global_state.use_raylet: - local_sched_client = ray.worker.global_worker.local_scheduler_client - for o in oids: - ray_obj_id = ray.ObjectID(o) - local_sched_client.reconstruct_objects([ray_obj_id], True) - else: - for o in oids: - plasma_id = ray.pyarrow.plasma.ObjectID(o) - ray.worker.global_worker.plasma_client.fetch([plasma_id]) + local_sched_client = ray.worker.global_worker.local_scheduler_client + for o in oids: + ray_obj_id = ray.ObjectID(o) + local_sched_client.fetch_or_reconstruct([ray_obj_id], True) def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""): diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index d91165637b609..d97cc274f76d6 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -2,12 +2,8 @@ from __future__ import division from __future__ import print_function -import copy from collections import defaultdict -import heapq import json -import numbers -import os import redis import sys import time @@ -18,25 +14,6 @@ from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) -# This mapping from integer to task state string must be kept up-to-date with -# the scheduling_state enum in task.h. -TASK_STATUS_WAITING = 1 -TASK_STATUS_SCHEDULED = 2 -TASK_STATUS_QUEUED = 4 -TASK_STATUS_RUNNING = 8 -TASK_STATUS_DONE = 16 -TASK_STATUS_LOST = 32 -TASK_STATUS_RECONSTRUCTING = 64 -TASK_STATUS_MAPPING = { - TASK_STATUS_WAITING: "WAITING", - TASK_STATUS_SCHEDULED: "SCHEDULED", - TASK_STATUS_QUEUED: "QUEUED", - TASK_STATUS_RUNNING: "RUNNING", - TASK_STATUS_DONE: "DONE", - TASK_STATUS_LOST: "LOST", - TASK_STATUS_RECONSTRUCTING: "RECONSTRUCTING", -} - class GlobalState(object): """A class used to interface with the Ray control state. @@ -47,7 +24,6 @@ class GlobalState(object): Attributes: redis_client: The Redis client used to query the primary redis server. redis_clients: Redis clients for each of the Redis shards. - use_raylet: True if we are using the raylet code path. """ def __init__(self): @@ -57,8 +33,6 @@ def __init__(self): self.redis_client = None # Clients for the redis shards, storing the object table & task table. self.redis_clients = None - # True if we are using the raylet code path and false otherwise. - self.use_raylet = None def _check_connected(self): """Check that the object has been initialized before it is used. @@ -78,6 +52,7 @@ def _check_connected(self): def _initialize_global_state(self, redis_ip_address, redis_port, + redis_password=None, timeout=20): """Initialize the GlobalState object by connecting to Redis. @@ -89,9 +64,10 @@ def _initialize_global_state(self, redis_ip_address: The IP address of the node that the Redis server lives on. redis_port: The port that the Redis server is listening on. + redis_password: The password of the redis server. 
""" self.redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=redis_password) start_time = time.time() @@ -128,22 +104,15 @@ def _initialize_global_state(self, "ip_address_ports = {}".format( num_redis_shards, ip_address_ports)) - use_raylet = self.redis_client.get("UseRaylet") - if use_raylet is not None: - self.use_raylet = int(use_raylet) == 1 - elif os.environ.get("RAY_USE_XRAY") == "1": - # This environment variable is used in our testing setup. - print("Detected environment variable 'RAY_USE_XRAY'.") - self.use_raylet = True - else: - self.use_raylet = False - # Get the rest of the information. self.redis_clients = [] for ip_address_port in ip_address_ports: shard_address, shard_port = ip_address_port.split(b":") self.redis_clients.append( - redis.StrictRedis(host=shard_address, port=shard_port)) + redis.StrictRedis( + host=shard_address, + port=shard_port, + password=redis_password)) def _execute_command(self, key, *args): """Execute a Redis command on the appropriate Redis shard based on key. @@ -188,53 +157,29 @@ def _object_table(self, object_id): object_id = ray.ObjectID(hex_to_binary(object_id)) # Return information about a single object ID. - if not self.use_raylet: - # Use the non-raylet code path. - object_locations = self._execute_command( - object_id, "RAY.OBJECT_TABLE_LOOKUP", object_id.id()) - if object_locations is not None: - manager_ids = [ - binary_to_hex(manager_id) - for manager_id in object_locations - ] - else: - manager_ids = None - - result_table_response = self._execute_command( - object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id()) - result_table_message = ( - ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply( - result_table_response, 0)) - - result = { - "ManagerIDs": manager_ids, - "TaskID": binary_to_hex(result_table_message.TaskId()), - "IsPut": bool(result_table_message.IsPut()), - "DataSize": result_table_message.DataSize(), - "Hash": binary_to_hex(result_table_message.Hash()) - } + message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", + ray.gcs_utils.TablePrefix.OBJECT, "", + object_id.id()) + gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + message, 0) - else: - # Use the raylet code path. - message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.OBJECT, - "", object_id.id()) - result = [] - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) - - for i in range(gcs_entry.EntriesLength()): - entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( - gcs_entry.Entries(i), 0) - object_info = { - "DataSize": entry.ObjectSize(), - "Manager": entry.Manager(), - "IsEviction": entry.IsEviction(), - "NumEvictions": entry.NumEvictions() - } - result.append(object_info) + assert gcs_entry.EntriesLength() > 0 - return result + entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( + gcs_entry.Entries(0), 0) + + object_info = { + "DataSize": entry.ObjectSize(), + "Manager": entry.Manager(), + "IsEviction": [entry.IsEviction()], + } + + for i in range(1, gcs_entry.EntriesLength()): + entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( + gcs_entry.Entries(i), 0) + object_info["IsEviction"].append(entry.IsEviction()) + + return object_info def object_table(self, object_id=None): """Fetch and parse the object table info for one or more object IDs. 
@@ -252,25 +197,12 @@ def object_table(self, object_id=None): return self._object_table(object_id) else: # Return the entire object table. - if not self.use_raylet: - object_info_keys = self._keys( - ray.gcs_utils.OBJECT_INFO_PREFIX + "*") - object_location_keys = self._keys( - ray.gcs_utils.OBJECT_LOCATION_PREFIX + "*") - object_ids_binary = set([ - key[len(ray.gcs_utils.OBJECT_INFO_PREFIX):] - for key in object_info_keys - ] + [ - key[len(ray.gcs_utils.OBJECT_LOCATION_PREFIX):] - for key in object_location_keys - ]) - else: - object_keys = self._keys( - ray.gcs_utils.TablePrefix_OBJECT_string + "*") - object_ids_binary = { - key[len(ray.gcs_utils.TablePrefix_OBJECT_string):] - for key in object_keys - } + object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string + + "*") + object_ids_binary = { + key[len(ray.gcs_utils.TablePrefix_OBJECT_string):] + for key in object_keys + } results = {} for object_id_binary in object_ids_binary: @@ -287,112 +219,49 @@ def _task_table(self, task_id): Returns: A dictionary with information about the task ID in question. - TASK_STATUS_MAPPING should be used to parse the "State" field - into a human-readable string. """ - if not self.use_raylet: - # Use the non-raylet code path. - task_table_response = self._execute_command( - task_id, "RAY.TASK_TABLE_GET", task_id.id()) - if task_table_response is None: - raise Exception("There is no entry for task ID {} in the task " - "table.".format(binary_to_hex(task_id.id()))) - task_table_message = ray.gcs_utils.TaskReply.GetRootAsTaskReply( - task_table_response, 0) - task_spec = task_table_message.TaskSpec() - task_spec = ray.local_scheduler.task_from_string(task_spec) - - task_spec_info = { - "DriverID": binary_to_hex(task_spec.driver_id().id()), - "TaskID": binary_to_hex(task_spec.task_id().id()), - "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()), - "ParentCounter": task_spec.parent_counter(), - "ActorID": binary_to_hex(task_spec.actor_id().id()), - "ActorCreationID": binary_to_hex( - task_spec.actor_creation_id().id()), - "ActorCreationDummyObjectID": binary_to_hex( - task_spec.actor_creation_dummy_object_id().id()), - "ActorCounter": task_spec.actor_counter(), - "FunctionID": binary_to_hex(task_spec.function_id().id()), - "Args": task_spec.arguments(), - "ReturnObjectIDs": task_spec.returns(), - "RequiredResources": task_spec.required_resources() - } - - execution_dependencies_message = ( - ray.gcs_utils.TaskExecutionDependencies. - GetRootAsTaskExecutionDependencies( - task_table_message.ExecutionDependencies(), 0)) - execution_dependencies = [ - ray.ObjectID( - execution_dependencies_message.ExecutionDependencies(i)) - for i in range(execution_dependencies_message. - ExecutionDependenciesLength()) - ] - - # TODO(rkn): The return fields ExecutionDependenciesString and - # ExecutionDependencies are redundant, so we should remove - # ExecutionDependencies. However, it is currently used in - # monitor.py. - - return { - "State": task_table_message.State(), - "LocalSchedulerID": binary_to_hex( - task_table_message.LocalSchedulerId()), - "ExecutionDependenciesString": task_table_message. - ExecutionDependencies(), - "ExecutionDependencies": execution_dependencies, - "SpillbackCount": task_table_message.SpillbackCount(), - "TaskSpec": task_spec_info - } - - else: - # Use the raylet code path. 
- message = self._execute_command( - task_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.RAYLET_TASK, "", task_id.id()) - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) - - info = [] - for i in range(gcs_entries.EntriesLength()): - task_table_message = ray.gcs_utils.Task.GetRootAsTask( - gcs_entries.Entries(i), 0) - - execution_spec = task_table_message.TaskExecutionSpec() - task_spec = task_table_message.TaskSpecification() - task_spec = ray.local_scheduler.task_from_string(task_spec) - task_spec_info = { - "DriverID": binary_to_hex(task_spec.driver_id().id()), - "TaskID": binary_to_hex(task_spec.task_id().id()), - "ParentTaskID": binary_to_hex( - task_spec.parent_task_id().id()), - "ParentCounter": task_spec.parent_counter(), - "ActorID": binary_to_hex(task_spec.actor_id().id()), - "ActorCreationID": binary_to_hex( - task_spec.actor_creation_id().id()), - "ActorCreationDummyObjectID": binary_to_hex( - task_spec.actor_creation_dummy_object_id().id()), - "ActorCounter": task_spec.actor_counter(), - "FunctionID": binary_to_hex(task_spec.function_id().id()), - "Args": task_spec.arguments(), - "ReturnObjectIDs": task_spec.returns(), - "RequiredResources": task_spec.required_resources() - } + message = self._execute_command(task_id, "RAY.TABLE_LOOKUP", + ray.gcs_utils.TablePrefix.RAYLET_TASK, + "", task_id.id()) + gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + message, 0) - info.append({ - "ExecutionSpec": { - "Dependencies": [ - execution_spec.Dependencies(i) - for i in range(execution_spec.DependenciesLength()) - ], - "LastTimestamp": execution_spec.LastTimestamp(), - "NumForwards": execution_spec.NumForwards() - }, - "TaskSpec": task_spec_info - }) + assert gcs_entries.EntriesLength() == 1 + + task_table_message = ray.gcs_utils.Task.GetRootAsTask( + gcs_entries.Entries(0), 0) + + execution_spec = task_table_message.TaskExecutionSpec() + task_spec = task_table_message.TaskSpecification() + task_spec = ray.raylet.task_from_string(task_spec) + task_spec_info = { + "DriverID": binary_to_hex(task_spec.driver_id().id()), + "TaskID": binary_to_hex(task_spec.task_id().id()), + "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()), + "ParentCounter": task_spec.parent_counter(), + "ActorID": binary_to_hex(task_spec.actor_id().id()), + "ActorCreationID": binary_to_hex( + task_spec.actor_creation_id().id()), + "ActorCreationDummyObjectID": binary_to_hex( + task_spec.actor_creation_dummy_object_id().id()), + "ActorCounter": task_spec.actor_counter(), + "FunctionID": binary_to_hex(task_spec.function_id().id()), + "Args": task_spec.arguments(), + "ReturnObjectIDs": task_spec.returns(), + "RequiredResources": task_spec.required_resources() + } - return info + return { + "ExecutionSpec": { + "Dependencies": [ + execution_spec.Dependencies(i) + for i in range(execution_spec.DependenciesLength()) + ], + "LastTimestamp": execution_spec.LastTimestamp(), + "NumForwards": execution_spec.NumForwards() + }, + "TaskSpec": task_spec_info + } def task_table(self, task_id=None): """Fetch and parse the task table information for one or more task IDs. 
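Usage note for the simplified _task_table above: the raylet path is now the only path, so each lookup comes back as a single dict with "ExecutionSpec" and "TaskSpec" keys (there is no longer a "State" field or TASK_STATUS_MAPPING to decode). A minimal sketch, assuming a connected driver and that at least one task has already run; the field names follow the dict built in _task_table above:

import ray

ray.init()


@ray.remote
def f():
    return 1


ray.get(f.remote())

# Whole table, keyed by hex task ID.
for task_id, entry in ray.global_state.task_table().items():
    spec = entry["TaskSpec"]
    deps = entry["ExecutionSpec"]["Dependencies"]
    print(task_id, spec["FunctionID"], len(deps))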
@@ -409,19 +278,12 @@ def task_table(self, task_id=None): task_id = ray.ObjectID(hex_to_binary(task_id)) return self._task_table(task_id) else: - if not self.use_raylet: - task_table_keys = self._keys(ray.gcs_utils.TASK_PREFIX + "*") - task_ids_binary = [ - key[len(ray.gcs_utils.TASK_PREFIX):] - for key in task_table_keys - ] - else: - task_table_keys = self._keys( - ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*") - task_ids_binary = [ - key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):] - for key in task_table_keys - ] + task_table_keys = self._keys( + ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*") + task_ids_binary = [ + key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):] + for key in task_table_keys + ] results = {} for task_id_binary in task_ids_binary: @@ -457,95 +319,54 @@ def client_table(self): Information about the Ray clients in the cluster. """ self._check_connected() - if not self.use_raylet: - db_client_keys = self.redis_client.keys( - ray.gcs_utils.DB_CLIENT_PREFIX + "*") - node_info = {} - for key in db_client_keys: - client_info = self.redis_client.hgetall(key) - node_ip_address = decode(client_info[b"node_ip_address"]) - if node_ip_address not in node_info: - node_info[node_ip_address] = [] - client_info_parsed = {} - assert b"client_type" in client_info - assert b"deleted" in client_info - assert b"ray_client_id" in client_info - for field, value in client_info.items(): - if field == b"node_ip_address": - pass - elif field == b"client_type": - client_info_parsed["ClientType"] = decode(value) - elif field == b"deleted": - client_info_parsed["Deleted"] = bool( - int(decode(value))) - elif field == b"ray_client_id": - client_info_parsed["DBClientID"] = binary_to_hex(value) - elif field == b"manager_address": - client_info_parsed["AuxAddress"] = decode(value) - elif field == b"local_scheduler_socket_name": - client_info_parsed["LocalSchedulerSocketName"] = ( - decode(value)) - elif client_info[b"client_type"] == b"local_scheduler": - # The remaining fields are resource types. - client_info_parsed[decode(field)] = float( - decode(value)) - else: - client_info_parsed[decode(field)] = decode(value) - - node_info[node_ip_address].append(client_info_parsed) - - return node_info - else: - # This is the raylet code path. - NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff" - message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "", - NIL_CLIENT_ID) - - # Handle the case where no clients are returned. This should only - # occur potentially immediately after the cluster is started. - if message is None: - return [] - - node_info = {} - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) - - # Since GCS entries are append-only, we override so that - # only the latest entries are kept. - for i in range(gcs_entry.EntriesLength()): - client = ( - ray.gcs_utils.ClientTableData.GetRootAsClientTableData( - gcs_entry.Entries(i), 0)) - - resources = { - decode(client.ResourcesTotalLabel(i)): - client.ResourcesTotalCapacity(i) - for i in range(client.ResourcesTotalLabelLength()) - } - client_id = ray.utils.binary_to_hex(client.ClientId()) - - # If this client is being removed, then it must - # have previously been inserted, and - # it cannot have previously been removed. - if not client.IsInsertion(): - assert client_id in node_info, "Client removed not found!" 
- assert node_info[client_id]["IsInsertion"], ( - "Unexpected duplicate removal of client.") - - node_info[client_id] = { - "ClientID": client_id, - "IsInsertion": client.IsInsertion(), - "NodeManagerAddress": decode(client.NodeManagerAddress()), - "NodeManagerPort": client.NodeManagerPort(), - "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName()), - "RayletSocketName": decode(client.RayletSocketName()), - "Resources": resources - } - return list(node_info.values()) + NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff" + message = self.redis_client.execute_command( + "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "", + NIL_CLIENT_ID) + + # Handle the case where no clients are returned. This should only + # occur potentially immediately after the cluster is started. + if message is None: + return [] + + node_info = {} + gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + message, 0) + + # Since GCS entries are append-only, we override so that + # only the latest entries are kept. + for i in range(gcs_entry.EntriesLength()): + client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData( + gcs_entry.Entries(i), 0)) + + resources = { + decode(client.ResourcesTotalLabel(i)): + client.ResourcesTotalCapacity(i) + for i in range(client.ResourcesTotalLabelLength()) + } + client_id = ray.utils.binary_to_hex(client.ClientId()) + + # If this client is being removed, then it must + # have previously been inserted, and + # it cannot have previously been removed. + if not client.IsInsertion(): + assert client_id in node_info, "Client removed not found!" + assert node_info[client_id]["IsInsertion"], ( + "Unexpected duplicate removal of client.") + + node_info[client_id] = { + "ClientID": client_id, + "IsInsertion": client.IsInsertion(), + "NodeManagerAddress": decode(client.NodeManagerAddress()), + "NodeManagerPort": client.NodeManagerPort(), + "ObjectManagerPort": client.ObjectManagerPort(), + "ObjectStoreSocketName": decode( + client.ObjectStoreSocketName()), + "RayletSocketName": decode(client.RayletSocketName()), + "Resources": resources + } + return list(node_info.values()) def log_files(self): """Fetch and return a dictionary of log file names to outputs. @@ -575,129 +396,6 @@ def log_files(self): return ip_filename_file - def task_profiles(self, num_tasks, start=None, end=None, fwd=True): - """Fetch and return a list of task profiles. - - Args: - num_tasks: A limit on the number of tasks that task_profiles will - return. - start: The start point of the time window that is queried for - tasks. - end: The end point in time of the time window that is queried for - tasks. - fwd: If True, means that zrange will be used. If False, zrevrange. - This argument is only meaningful in conjunction with the - num_tasks argument. This controls whether the tasks returned - are the most recent or the least recent. - - Returns: - A tuple of two elements. The first element is a dictionary mapping - the task ID of a task to a list of the profiling information - for all of the executions of that task. The second element is a - list of profiling information for tasks where the events have - no task ID. - """ - task_info = {} - event_log_sets = self.redis_client.keys("event_log*") - - # The heap is used to maintain the set of x tasks that occurred the - # most recently across all of the workers, where x is defined as the - # function parameter num. The key is the start time of the "get_task" - # component of each task. 
Calling heappop will result in the task with - # the earliest "get_task_start" to be removed from the heap. - heap = [] - heapq.heapify(heap) - heap_size = 0 - - # Set up a param dict to pass the redis command - params = {"withscores": True} - if start is not None: - params["min"] = start - elif end is not None: - params["min"] = 0 - - if end is not None: - params["max"] = end - elif start is not None: - params["max"] = time.time() - - if start is None and end is None: - params["end"] = num_tasks - 1 - else: - params["num"] = num_tasks - params["start"] = 0 - - # Parse through event logs to determine task start and end points. - for event_log_set in event_log_sets: - if start is None and end is None: - if fwd: - event_list = self.redis_client.zrange( - event_log_set, **params) - else: - event_list = self.redis_client.zrevrange( - event_log_set, **params) - else: - if fwd: - event_list = self.redis_client.zrangebyscore( - event_log_set, **params) - else: - event_list = self.redis_client.zrevrangebyscore( - event_log_set, **params) - - for (event, score) in event_list: - event_dict = json.loads(decode(event)) - task_id = "" - for event in event_dict: - if "task_id" in event[3]: - task_id = event[3]["task_id"] - task_info[task_id] = {} - task_info[task_id]["score"] = score - # Add task to (min/max) heap by its start point. - # if fwd, we want to delete the largest elements, so -score - heapq.heappush(heap, (-score if fwd else score, task_id)) - heap_size += 1 - - for event in event_dict: - if event[1] == "get_task" and event[2] == 1: - task_info[task_id]["get_task_start"] = event[0] - if event[1] == "get_task" and event[2] == 2: - task_info[task_id]["get_task_end"] = event[0] - if (event[1] == "register_remote_function" - and event[2] == 1): - task_info[task_id]["import_remote_start"] = event[0] - if (event[1] == "register_remote_function" - and event[2] == 2): - task_info[task_id]["import_remote_end"] = event[0] - if (event[1] == "task:deserialize_arguments" - and event[2] == 1): - task_info[task_id]["get_arguments_start"] = event[0] - if (event[1] == "task:deserialize_arguments" - and event[2] == 2): - task_info[task_id]["get_arguments_end"] = event[0] - if event[1] == "task:execute" and event[2] == 1: - task_info[task_id]["execute_start"] = event[0] - if event[1] == "task:execute" and event[2] == 2: - task_info[task_id]["execute_end"] = event[0] - if event[1] == "task:store_outputs" and event[2] == 1: - task_info[task_id]["store_outputs_start"] = event[0] - if event[1] == "task:store_outputs" and event[2] == 2: - task_info[task_id]["store_outputs_end"] = event[0] - if "worker_id" in event[3]: - task_info[task_id]["worker_id"] = event[3]["worker_id"] - if "function_name" in event[3]: - task_info[task_id]["function_name"] = ( - event[3]["function_name"]) - - if heap_size > num_tasks: - min_task, task_id_hex = heapq.heappop(heap) - del task_info[task_id_hex] - heap_size -= 1 - - for key, info in task_info.items(): - self._add_missing_timestamps(info) - - return task_info - def _profile_table(self, component_id): """Get the profile events for a given component. 
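With task_profiles removed here (and dump_catapult_trace removed further down), the supported way to get a timeline is the profile-table based chrome_tracing_dump shown later in this diff. A minimal sketch of the replacement workflow, assuming a connected driver; the output path is arbitrary:

import ray

ray.init()


@ray.remote
def f():
    return 1


ray.get([f.remote() for _ in range(10)])

# Writes Chrome-trace JSON; open chrome://tracing, load the file, and
# enable "Flow events" under View Options.
ray.global_state.chrome_tracing_dump(filename="/tmp/ray_timeline.json")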
@@ -748,10 +446,6 @@ def _profile_table(self, component_id): return profile_events def profile_table(self): - if not self.use_raylet: - raise Exception("This method is only supported in the raylet " - "code path.") - profile_table_keys = self._keys( ray.gcs_utils.TablePrefix_PROFILE_string + "*") component_identifiers_binary = [ @@ -765,24 +459,78 @@ def profile_table(self): for component_id in component_identifiers_binary } - def chrome_tracing_dump(self, - include_task_data=False, - filename=None, - open_browser=False): + def _seconds_to_microseconds(self, time_in_seconds): + """A helper function for converting seconds to microseconds.""" + time_in_microseconds = 10**6 * time_in_seconds + return time_in_microseconds + + # Colors are specified at + # https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html. # noqa: E501 + _default_color_mapping = defaultdict( + lambda: "generic_work", { + "worker_idle": "cq_build_abandoned", + "task": "rail_response", + "task:deserialize_arguments": "rail_load", + "task:execute": "rail_animation", + "task:store_outputs": "rail_idle", + "wait_for_function": "detailed_memory_dump", + "ray.get": "good", + "ray.put": "terrible", + "ray.wait": "vsync_highlight_color", + "submit_task": "background_memory_dump", + "fetch_and_run_function": "detailed_memory_dump", + "register_remote_function": "detailed_memory_dump", + }) + + # These colors are for use in Chrome tracing. + _chrome_tracing_colors = [ + "thread_state_uninterruptible", + "thread_state_iowait", + "thread_state_running", + "thread_state_runnable", + "thread_state_sleeping", + "thread_state_unknown", + "background_memory_dump", + "light_memory_dump", + "detailed_memory_dump", + "vsync_highlight_color", + "generic_work", + "good", + "bad", + "terrible", + # "black", + # "grey", + # "white", + "yellow", + "olive", + "rail_response", + "rail_animation", + "rail_idle", + "rail_load", + "startup", + "heap_dump_stack_frame", + "heap_dump_object_type", + "heap_dump_child_node_arrow", + "cq_build_running", + "cq_build_passed", + "cq_build_failed", + "cq_build_abandoned", + "cq_build_attempt_runnig", + "cq_build_attempt_passed", + "cq_build_attempt_failed", + ] + + def chrome_tracing_dump(self, filename=None): """Return a list of profiling events that can viewed as a timeline. To view this information as a timeline, simply dump it as a json file - using json.dumps, and then load go to chrome://tracing in the Chrome - web browser and load the dumped file. Make sure to enable "Flow events" - in the "View Options" menu. + by passing in "filename" or using using json.dump, and then load go to + chrome://tracing in the Chrome web browser and load the dumped file. + Make sure to enable "Flow events" in the "View Options" menu. Args: - include_task_data: If true, we will include more task metadata such - as the task specifications in the json. filename: If a filename is provided, the timeline is dumped to that file. - open_browser: If true, we will attempt to automatically open the - timeline visualization in Chrome. Returns: If filename is not provided, this returns a list of profiling @@ -793,38 +541,15 @@ def chrome_tracing_dump(self, # TODO(rkn): This should support viewing just a window of time or a # limited number of events. 
- if include_task_data: - raise NotImplementedError("This flag has not been implented yet.") - - if open_browser: - raise NotImplementedError("This flag has not been implented yet.") - profile_table = self.profile_table() all_events = [] - # Colors are specified at - # https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html. # noqa: E501 - default_color_mapping = defaultdict( - lambda: "generic_work", { - "get_task": "cq_build_abandoned", - "task": "rail_response", - "task:deserialize_arguments": "rail_load", - "task:execute": "rail_animation", - "task:store_outputs": "rail_idle", - "wait_for_function": "detailed_memory_dump", - "ray.get": "good", - "ray.put": "terrible", - "ray.wait": "vsync_highlight_color", - "submit_task": "background_memory_dump", - "fetch_and_run_function": "detailed_memory_dump", - "register_remote_function": "detailed_memory_dump", - }) - - def seconds_to_microseconds(time_in_seconds): - time_in_microseconds = 10**6 * time_in_seconds - return time_in_microseconds - for component_id_hex, component_events in profile_table.items(): + # Only consider workers and drivers. + component_type = component_events[0]["component_type"] + if component_type not in ["worker", "driver"]: + continue + for event in component_events: new_event = { # The category of the event. @@ -838,14 +563,14 @@ def seconds_to_microseconds(time_in_seconds): "tid": event["component_type"] + ":" + event["component_id"], # The start time in microseconds. - "ts": seconds_to_microseconds(event["start_time"]), + "ts": self._seconds_to_microseconds(event["start_time"]), # The duration in microseconds. - "dur": seconds_to_microseconds(event["end_time"] - - event["start_time"]), + "dur": self._seconds_to_microseconds(event["end_time"] - + event["start_time"]), # What is this? "ph": "X", # This is the name of the color to display the box in. - "cname": default_color_mapping[event["event_type"]], + "cname": self._default_color_mapping[event["event_type"]], # The extra user-defined data. "args": event["extra_data"], } @@ -865,357 +590,96 @@ def seconds_to_microseconds(time_in_seconds): else: return all_events - def dump_catapult_trace(self, - path, - task_info, - breakdowns=True, - task_dep=True, - obj_dep=True): - """Dump task profiling information to a file. + def chrome_tracing_object_transfer_dump(self, filename=None): + """Return a list of transfer events that can viewed as a timeline. - This information can be viewed as a timeline of profiling information - by going to chrome://tracing in the chrome web browser and loading the - appropriate file. + To view this information as a timeline, simply dump it as a json file + by passing in "filename" or using using json.dump, and then load go to + chrome://tracing in the Chrome web browser and load the dumped file. + Make sure to enable "Flow events" in the "View Options" menu. Args: - path: The filepath to dump the profiling information to. - task_info: The task info to use to generate the trace. Should be - the output of ray.global_state.task_profiles(). - breakdowns: Boolean indicating whether to break down the tasks into - more fine-grained segments. - task_dep: Boolean indicating whether or not task submission edges - should be included in the trace. - obj_dep: Boolean indicating whether or not object dependency edges - should be included in the trace. 
- """ - workers = self.workers() - - task_table = {} - # TODO(ekl) reduce the number of RPCs here with MGET - for task_id, _ in task_info.items(): - try: - # TODO (hme): do something to correct slider here, - # slider should be correct to begin with, though. - task_table[task_id] = self.task_table(task_id) - task_table[task_id]["TaskSpec"]["Args"] = [ - repr(arg) - for arg in task_table[task_id]["TaskSpec"]["Args"] - ] - except Exception as e: - print("Could not find task {}".format(task_id)) - - # filter out tasks not in task_table - task_info = {k: v for k, v in task_info.items() if k in task_table} - - start_time = None - for info in task_info.values(): - task_start = min(self._get_times(info)) - if not start_time or task_start < start_time: - start_time = task_start - - def micros(ts): - return int(1e6 * ts) - - def micros_rel(ts): - return micros(ts - start_time) - - seen_obj = {} - - full_trace = [] - for task_id, info in task_info.items(): - worker = workers[info["worker_id"]] - task_t_info = task_table[task_id] - - # The total_info dictionary is what is displayed when selecting a - # task in the timeline. We copy the task spec so that we don't - # modify it in place since we will use the original values later. - total_info = copy.copy(task_table[task_id]["TaskSpec"]) - total_info["Args"] = [ - oid.hex() if isinstance(oid, ray.ObjectID) else oid - for oid in task_t_info["TaskSpec"]["Args"] - ] - total_info["ReturnObjectIDs"] = [ - oid.hex() for oid in task_t_info["TaskSpec"]["ReturnObjectIDs"] - ] - total_info["LocalSchedulerID"] = task_t_info["LocalSchedulerID"] - total_info["get_arguments"] = ( - info["get_arguments_end"] - info["get_arguments_start"]) - total_info["execute"] = ( - info["execute_end"] - info["execute_start"]) - total_info["store_outputs"] = ( - info["store_outputs_end"] - info["store_outputs_start"]) - total_info["function_name"] = info["function_name"] - total_info["worker_id"] = info["worker_id"] - - parent_info = task_info.get( - task_table[task_id]["TaskSpec"]["ParentTaskID"]) - worker = workers[info["worker_id"]] - # The catapult trace format documentation can be found here: - # https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview # noqa: E501 - if breakdowns: - if "get_arguments_end" in info: - get_args_trace = { - "cat": "get_arguments", - "pid": "Node " + worker["node_ip_address"], - "tid": info["worker_id"], - "id": task_id, - "ts": micros_rel(info["get_arguments_start"]), - "ph": "X", - "name": info["function_name"] + ":get_arguments", - "args": total_info, - "dur": micros(info["get_arguments_end"] - - info["get_arguments_start"]), - "cname": "rail_idle" - } - full_trace.append(get_args_trace) - - if "store_outputs_end" in info: - outputs_trace = { - "cat": "store_outputs", - "pid": "Node " + worker["node_ip_address"], - "tid": info["worker_id"], - "id": task_id, - "ts": micros_rel(info["store_outputs_start"]), - "ph": "X", - "name": info["function_name"] + ":store_outputs", - "args": total_info, - "dur": micros(info["store_outputs_end"] - - info["store_outputs_start"]), - "cname": "thread_state_runnable" - } - full_trace.append(outputs_trace) - - if "execute_end" in info: - execute_trace = { - "cat": "execute", - "pid": "Node " + worker["node_ip_address"], - "tid": info["worker_id"], - "id": task_id, - "ts": micros_rel(info["execute_start"]), - "ph": "X", - "name": info["function_name"] + ":execute", - "args": total_info, - "dur": micros(info["execute_end"] - - info["execute_start"]), - "cname": "rail_animation" - } - 
full_trace.append(execute_trace) - - else: - if parent_info: - parent_worker = workers[parent_info["worker_id"]] - parent_times = self._get_times(parent_info) - parent_profile = task_info.get( - task_table[task_id]["TaskSpec"]["ParentTaskID"]) - - _parent_id = parent_info["worker_id"] + str( - micros(min(parent_times))) - - parent = { - "cat": "submit_task", - "pid": "Node " + parent_worker["node_ip_address"], - "tid": parent_info["worker_id"], - "ts": micros_rel( - parent_profile - and parent_profile["get_arguments_start"] - or start_time), - "ph": "s", - "name": "SubmitTask", - "args": {}, - "id": _parent_id, - } - full_trace.append(parent) - - _id = info["worker_id"] + str(micros(min(parent_times))) - - task_trace = { - "cat": "submit_task", - "pid": "Node " + worker["node_ip_address"], - "tid": info["worker_id"], - "ts": micros_rel(info["get_arguments_start"]), - "ph": "f", - "name": "SubmitTask", - "args": {}, - "id": _id, - "bp": "e", - "cname": "olive" - } - full_trace.append(task_trace) - - task = { - "cat": "task", - "pid": "Node " + worker["node_ip_address"], - "tid": info["worker_id"], - "id": task_id, - "ts": micros_rel(info["get_arguments_start"]), - "ph": "X", - "name": info["function_name"], - "args": total_info, - "dur": micros(info["store_outputs_end"] - - info["get_arguments_start"]), - "cname": "thread_state_runnable" - } - full_trace.append(task) - - if task_dep: - if parent_info: - parent_worker = workers[parent_info["worker_id"]] - parent_times = self._get_times(parent_info) - parent_profile = task_info.get( - task_table[task_id]["TaskSpec"]["ParentTaskID"]) - - _parent_id = parent_info["worker_id"] + str( - micros(min(parent_times))) - - parent = { - "cat": "submit_task", - "pid": "Node " + parent_worker["node_ip_address"], - "tid": parent_info["worker_id"], - "ts": micros_rel( - parent_profile - and parent_profile["get_arguments_start"] - or start_time), - "ph": "s", - "name": "SubmitTask", - "args": {}, - "id": _parent_id, - } - full_trace.append(parent) - - _id = info["worker_id"] + str(micros(min(parent_times))) - - task_trace = { - "cat": "submit_task", - "pid": "Node " + worker["node_ip_address"], - "tid": info["worker_id"], - "ts": micros_rel(info["get_arguments_start"]), - "ph": "f", - "name": "SubmitTask", - "args": {}, - "id": _id, - "bp": "e" - } - full_trace.append(task_trace) - - if obj_dep: - args = task_table[task_id]["TaskSpec"]["Args"] - for arg in args: - # Don't visualize arguments that are not object IDs. - if isinstance(arg, ray.ObjectID): - object_info = self._object_table(arg) - # Don't visualize objects that were created by calls to - # put. - if not object_info["IsPut"]: - if arg not in seen_obj: - seen_obj[arg] = 0 - seen_obj[arg] += 1 - owner_task = self._object_table(arg)["TaskID"] - if owner_task in task_info: - owner_worker = (workers[task_info[owner_task][ - "worker_id"]]) - # Adding/subtracting 2 to the time associated - # with the beginning/ending of the flow event - # is necessary to make the flow events show up - # reliably. When these times are exact, this is - # presumably an edge case, and catapult doesn't - # recognize that there is a duration event at - # that exact point in time that the flow event - # should be bound to. This issue is solved by - # adding the 2 ms to the start/end time of the - # flow event, which guarantees overlap with the - # duration event that it's associated with, and - # the flow event therefore always gets drawn. 
- owner = { - "cat": "obj_dependency", - "pid": ("Node " + - owner_worker["node_ip_address"]), - "tid": task_info[owner_task]["worker_id"], - "ts": micros_rel(task_info[owner_task] - ["store_outputs_end"]) - - 2, - "ph": "s", - "name": "ObjectDependency", - "args": {}, - "bp": "e", - "cname": "cq_build_attempt_failed", - "id": "obj" + str(arg) + str(seen_obj[arg]) - } - full_trace.append(owner) - - dependent = { - "cat": "obj_dependency", - "pid": "Node " + worker["node_ip_address"], - "tid": info["worker_id"], - "ts": micros_rel(info["get_arguments_start"]) + - 2, - "ph": "f", - "name": "ObjectDependency", - "args": {}, - "cname": "cq_build_attempt_failed", - "bp": "e", - "id": "obj" + str(arg) + str(seen_obj[arg]) - } - full_trace.append(dependent) - - print("Creating JSON {}/{}".format(len(full_trace), len(task_info))) - with open(path, "w") as outfile: - json.dump(full_trace, outfile) - - def _get_times(self, data): - """Extract the numerical times from a task profile. - - This is a helper method for dump_catapult_trace. + filename: If a filename is provided, the timeline is dumped to that + file. - Args: - data: This must be a value in the dictionary returned by the - task_profiles function. - """ - all_times = [] - all_times.append(data["acquire_lock_start"]) - all_times.append(data["acquire_lock_end"]) - all_times.append(data["get_arguments_start"]) - all_times.append(data["get_arguments_end"]) - all_times.append(data["execute_start"]) - all_times.append(data["execute_end"]) - all_times.append(data["store_outputs_start"]) - all_times.append(data["store_outputs_end"]) - return all_times - - def _add_missing_timestamps(self, info): - """Fills in any missing timestamp values in a task info. - - Task timestamps may be missing if the task fails or is partially - executed. + Returns: + If filename is not provided, this returns a list of profiling + events. Each profile event is a dictionary. """ + client_id_to_address = {} + for client_info in ray.global_state.client_table(): + client_id_to_address[client_info["ClientID"]] = "{}:{}".format( + client_info["NodeManagerAddress"], + client_info["ObjectManagerPort"]) - keys = [ - "acquire_lock_start", "acquire_lock_end", "get_arguments_start", - "get_arguments_end", "execute_start", "execute_end", - "store_outputs_start", "store_outputs_end" - ] + all_events = [] + + for key, items in self.profile_table().items(): + # Only consider object manager events. + if items[0]["component_type"] != "object_manager": + continue - latest_timestamp = 0 - for key in keys: - cur = info.get(key, latest_timestamp) - info[key] = cur - latest_timestamp = cur + for event in items: + if event["event_type"] == "transfer_send": + object_id, remote_client_id, _, _ = event["extra_data"] - def local_schedulers(self): - """Get a list of live local schedulers. + elif event["event_type"] == "transfer_receive": + object_id, remote_client_id, _, _ = event["extra_data"] - Returns: - A list of the live local schedulers. - """ - if self.use_raylet: - raise Exception("The local_schedulers() method is deprecated.") - clients = self.client_table() - local_schedulers = [] - for ip_address, client_list in clients.items(): - for client in client_list: - if (client["ClientType"] == "local_scheduler" - and not client["Deleted"]): - local_schedulers.append(client) - return local_schedulers + elif event["event_type"] == "receive_pull_request": + object_id, remote_client_id = event["extra_data"] + + else: + assert False, "This should be unreachable." 
+ + # Choose a color by reading the first couple of hex digits of + # the object ID as an integer and turning that into a color. + object_id_int = int(object_id[:2], 16) + color = self._chrome_tracing_colors[object_id_int % len( + self._chrome_tracing_colors)] + + new_event = { + # The category of the event. + "cat": event["event_type"], + # The string displayed on the event. + "name": event["event_type"], + # The identifier for the group of rows that the event + # appears in. + "pid": client_id_to_address[key], + # The identifier for the row that the event appears in. + "tid": client_id_to_address[remote_client_id], + # The start time in microseconds. + "ts": self._seconds_to_microseconds(event["start_time"]), + # The duration in microseconds. + "dur": self._seconds_to_microseconds(event["end_time"] - + event["start_time"]), + # What is this? + "ph": "X", + # This is the name of the color to display the box in. + "cname": color, + # The extra user-defined data. + "args": event["extra_data"], + } + all_events.append(new_event) + + # Add another box with a color indicating whether it was a send + # or a receive event. + if event["event_type"] == "transfer_send": + additional_event = new_event.copy() + additional_event["cname"] = "black" + all_events.append(additional_event) + elif event["event_type"] == "transfer_receive": + additional_event = new_event.copy() + additional_event["cname"] = "grey" + all_events.append(additional_event) + else: + pass + + if filename is not None: + with open(filename, "w") as outfile: + json.dump(all_events, outfile) + else: + return all_events def workers(self): """Get a dictionary mapping worker ID to worker information.""" @@ -1227,11 +691,7 @@ def workers(self): worker_id = binary_to_hex(worker_key[len("Workers:"):]) workers_data[worker_id] = { - "local_scheduler_socket": (decode( - worker_info[b"local_scheduler_socket"])), "node_ip_address": decode(worker_info[b"node_ip_address"]), - "plasma_manager_socket": decode( - worker_info[b"plasma_manager_socket"]), "plasma_store_socket": decode( worker_info[b"plasma_store_socket"]) } @@ -1291,28 +751,28 @@ def cluster_resources(self): resource in the cluster. """ resources = defaultdict(int) - if not self.use_raylet: - local_schedulers = self.local_schedulers() - - for local_scheduler in local_schedulers: - for key, value in local_scheduler.items(): - if key not in [ - "ClientType", "Deleted", "DBClientID", - "AuxAddress", "LocalSchedulerSocketName" - ]: - resources[key] += value - - else: - clients = self.client_table() - for client in clients: + clients = self.client_table() + for client in clients: + # Only count resources from live clients. + if client["IsInsertion"]: for key, value in client["Resources"].items(): resources[key] += value return dict(resources) + def _live_client_ids(self): + """Returns a set of client IDs corresponding to clients still alive.""" + return { + client["ClientID"] + for client in self.client_table() if client["IsInsertion"] + } + def available_resources(self): """Get the current available cluster resources. + This is different from `cluster_resources` in that this will return + idle (available) resources rather than total resources. + Note that this information can grow stale as tasks start and finish. 
Returns: @@ -1321,97 +781,48 @@ def available_resources(self): """ available_resources_by_id = {} - if not self.use_raylet: - subscribe_client = self.redis_client.pubsub() - subscribe_client.subscribe( - ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL) + subscribe_clients = [ + redis_client.pubsub(ignore_subscribe_messages=True) + for redis_client in self.redis_clients + ] + for subscribe_client in subscribe_clients: + subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) - local_scheduler_ids = { - local_scheduler["DBClientID"] - for local_scheduler in self.local_schedulers() - } + client_ids = self._live_client_ids() - while set(available_resources_by_id.keys()) != local_scheduler_ids: + while set(available_resources_by_id.keys()) != client_ids: + for subscribe_client in subscribe_clients: + # Parse client message raw_message = subscribe_client.get_message() - if raw_message is None: + if (raw_message is None or raw_message["channel"] != + ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): continue data = raw_message["data"] - # Ignore subscribtion success message from Redis - # This is a long in python 2 and an int in python 3 - if isinstance(data, numbers.Number): - continue - message = (ray.gcs_utils.LocalSchedulerInfoMessage. - GetRootAsLocalSchedulerInfoMessage(data, 0)) - num_resources = message.DynamicResourcesLength() + gcs_entries = ( + ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + data, 0)) + heartbeat_data = gcs_entries.Entries(0) + message = (ray.gcs_utils.HeartbeatTableData. + GetRootAsHeartbeatTableData(heartbeat_data, 0)) + # Calculate available resources for this client + num_resources = message.ResourcesAvailableLabelLength() dynamic_resources = {} for i in range(num_resources): - dyn = message.DynamicResources(i) - resource_id = decode(dyn.Key()) - dynamic_resources[resource_id] = dyn.Value() + resource_id = decode(message.ResourcesAvailableLabel(i)) + dynamic_resources[resource_id] = ( + message.ResourcesAvailableCapacity(i)) - # Update available resources for this local scheduler - client_id = binary_to_hex(message.DbClientId()) + # Update available resources for this client + client_id = ray.utils.binary_to_hex(message.ClientId()) available_resources_by_id[client_id] = dynamic_resources - # Update local schedulers in cluster - local_scheduler_ids = { - local_scheduler["DBClientID"] - for local_scheduler in self.local_schedulers() - } - - # Remove disconnected local schedulers - for local_scheduler_id in available_resources_by_id.keys(): - if local_scheduler_id not in local_scheduler_ids: - del available_resources_by_id[local_scheduler_id] - else: - # Assumes the number of Redis clients does not change - subscribe_clients = [ - redis_client.pubsub(ignore_subscribe_messages=True) - for redis_client in self.redis_clients - ] - for subscribe_client in subscribe_clients: - subscribe_client.subscribe( - ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) - - client_ids = {client["ClientID"] for client in self.client_table()} - - while set(available_resources_by_id.keys()) != client_ids: - for subscribe_client in subscribe_clients: - # Parse client message - raw_message = subscribe_client.get_message() - if (raw_message is None or raw_message["channel"] != - ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): - continue - data = raw_message["data"] - gcs_entries = ( - ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - data, 0)) - heartbeat_data = gcs_entries.Entries(0) - message = (ray.gcs_utils.HeartbeatTableData. 
- GetRootAsHeartbeatTableData(heartbeat_data, 0)) - # Calculate available resources for this client - num_resources = message.ResourcesAvailableLabelLength() - dynamic_resources = {} - for i in range(num_resources): - resource_id = decode( - message.ResourcesAvailableLabel(i)) - dynamic_resources[resource_id] = ( - message.ResourcesAvailableCapacity(i)) - - # Update available resources for this client - client_id = ray.utils.binary_to_hex(message.ClientId()) - available_resources_by_id[client_id] = dynamic_resources - - # Update clients in cluster - client_ids = { - client["ClientID"] - for client in self.client_table() - } + # Update clients in cluster + client_ids = self._live_client_ids() - # Remove disconnected clients - for client_id in available_resources_by_id.keys(): - if client_id not in client_ids: - del available_resources_by_id[client_id] + # Remove disconnected clients + for client_id in available_resources_by_id.keys(): + if client_id not in client_ids: + del available_resources_by_id[client_id] # Calculate total available resources total_available_resources = defaultdict(int) @@ -1464,10 +875,6 @@ def error_messages(self, job_id=None): A dictionary mapping job ID to a list of the error messages for that job. """ - if not self.use_raylet: - raise Exception("The error_messages method is only supported in " - "the raylet code path.") - if job_id is not None: return self._error_messages(job_id) diff --git a/python/ray/experimental/test/async_test.py b/python/ray/experimental/test/async_test.py new file mode 100644 index 0000000000000..bdf45f77e8282 --- /dev/null +++ b/python/ray/experimental/test/async_test.py @@ -0,0 +1,150 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import asyncio +import time + +import pytest + +import ray +from ray.experimental import async_api + + +@pytest.fixture +def init(): + ray.init(num_cpus=4) + async_api.init() + asyncio.get_event_loop().set_debug(False) + yield + async_api.shutdown() + ray.shutdown() + + +def gen_tasks(time_scale=0.1): + @ray.remote + def f(n): + time.sleep(n * time_scale) + return n + + tasks = [f.remote(i) for i in range(5)] + return tasks + + +def test_simple(init): + @ray.remote + def f(): + time.sleep(1) + return {"key1": ["value"]} + + future = async_api.as_future(f.remote()) + result = asyncio.get_event_loop().run_until_complete(future) + assert result["key1"] == ["value"] + + +def test_gather(init): + loop = asyncio.get_event_loop() + tasks = gen_tasks() + futures = [async_api.as_future(obj_id) for obj_id in tasks] + results = loop.run_until_complete(asyncio.gather(*futures)) + assert all(a == b for a, b in zip(results, ray.get(tasks))) + + +def test_gather_benchmark(init): + @ray.remote + def f(n): + time.sleep(0.001 * n) + return 42 + + async def test_async(): + sum_time = 0. + for _ in range(50): + tasks = [f.remote(n) for n in range(20)] + start = time.time() + futures = [async_api.as_future(obj_id) for obj_id in tasks] + await asyncio.gather(*futures) + sum_time += time.time() - start + return sum_time + + def baseline(): + sum_time = 0. + for _ in range(50): + tasks = [f.remote(n) for n in range(20)] + start = time.time() + ray.get(tasks) + sum_time += time.time() - start + return sum_time + + # warm up + baseline() + # async get + sum_time_1 = asyncio.get_event_loop().run_until_complete(test_async()) + # get + sum_time_2 = baseline() + + # Ensure the new implementation is not too slow. 
+ assert sum_time_2 * 1.2 > sum_time_1 + + +def test_wait(init): + loop = asyncio.get_event_loop() + tasks = gen_tasks() + futures = [async_api.as_future(obj_id) for obj_id in tasks] + results, _ = loop.run_until_complete(asyncio.wait(futures)) + assert set(results) == set(futures) + + +def test_wait_timeout(init): + loop = asyncio.get_event_loop() + tasks = gen_tasks(10) + futures = [async_api.as_future(obj_id) for obj_id in tasks] + fut = asyncio.wait(futures, timeout=5) + results, _ = loop.run_until_complete(fut) + assert list(results)[0] == futures[0] + + +def test_gather_mixup(init): + loop = asyncio.get_event_loop() + + @ray.remote + def f(n): + time.sleep(n * 0.1) + return n + + async def g(n): + await asyncio.sleep(n * 0.1) + return n + + tasks = [ + async_api.as_future(f.remote(1)), + g(2), + async_api.as_future(f.remote(3)), + g(4) + ] + results = loop.run_until_complete(asyncio.gather(*tasks)) + assert results == [1, 2, 3, 4] + + +def test_wait_mixup(init): + loop = asyncio.get_event_loop() + + @ray.remote + def f(n): + time.sleep(n) + return n + + def g(n): + async def _g(_n): + await asyncio.sleep(_n) + return _n + + return asyncio.ensure_future(_g(n)) + + tasks = [ + async_api.as_future(f.remote(0.1)), + g(7), + async_api.as_future(f.remote(5)), + g(2) + ] + ready, _ = loop.run_until_complete(asyncio.wait(tasks, timeout=4)) + assert set(ready) == {tasks[0], tasks[-1]} diff --git a/python/ray/experimental/ui.py b/python/ray/experimental/ui.py index da4ee9e57c838..15a6fd05f839d 100644 --- a/python/ray/experimental/ui.py +++ b/python/ray/experimental/ui.py @@ -1,20 +1,23 @@ -import ipywidgets as widgets +import logging import numpy as np import os import pprint -import ray import shutil import tempfile import time +import ipywidgets as widgets from IPython.display import display, IFrame, clear_output +import ray + +logger = logging.getLogger(__name__) + + # Instances of this class maintains keep track of whether or not a # callback is currently executing. Since the execution of the callback # may trigger more calls to the callback, this is used to prevent infinite # recursions. - - class _EventRecursionContextManager(object): def __init__(self): self.should_recurse = True @@ -185,36 +188,6 @@ def update_wrapper(event): range_slider.value = (100 + int( 100 * float(num_tasks_box.value) / num_tasks), 100) - if not update: - return - - diff = largest - smallest - - # Low and high are used to scale the times that are - # queried to be relative to the absolute time. - low, high = map(lambda x: x / 100., range_slider.value) - - # Queries to task_profiles based on the slider and text - # box values. - # (Querying based on the % total amount of time.) - if breakdown_opt.value == total_time_value: - tasks = _truncated_task_profiles( - start=(smallest + diff * low), - end=(smallest + diff * high)) - - # (Querying based on % of total number of tasks that were - # run.) - elif breakdown_opt.value == total_tasks_value: - if range_slider.value[0] == 0: - tasks = _truncated_task_profiles( - num_tasks=(int(num_tasks * high)), fwd=True) - else: - tasks = _truncated_task_profiles( - num_tasks=(int(num_tasks * (high - low))), - fwd=False) - - update(smallest, largest, num_tasks, tasks) - # Get updated values from a slider or text box, and update the rest of # them accordingly. 
range_slider.observe(update_wrapper, names="value") @@ -268,20 +241,6 @@ def handle_submit(sender): MAX_TASKS_TO_VISUALIZE = 10000 -# Wrapper that enforces a limit on the number of tasks to visualize -def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True): - if num_tasks is None: - num_tasks = MAX_TASKS_TO_VISUALIZE - print("Warning: at most {} tasks will be fetched within this " - "time range.".format(MAX_TASKS_TO_VISUALIZE)) - elif num_tasks > MAX_TASKS_TO_VISUALIZE: - print("Warning: too many tasks to visualize, " - "fetching only the first {} of {}.".format( - MAX_TASKS_TO_VISUALIZE, num_tasks)) - num_tasks = MAX_TASKS_TO_VISUALIZE - return ray.global_state.task_profiles(num_tasks, start, end, fwd) - - # Helper function that guarantees unique and writeable temp files. # Prevents clashes in task trace files when multiple notebooks are running. def _get_temp_file_path(**kwargs): @@ -293,32 +252,43 @@ def _get_temp_file_path(**kwargs): def task_timeline(): - path_input = widgets.Button(description="View task timeline") + # Check that the trace viewer renderer file is present, and copy it to the + # current working directory if it is not present. + if not os.path.exists("trace_viewer_full.html"): + shutil.copy( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "../core/src/catapult_files/trace_viewer_full.html"), + "trace_viewer_full.html") - breakdown_basic = "Basic" - breakdown_task = "Task Breakdowns" + trace_viewer_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "../core/src/catapult_files/index.html") - breakdown_opt = widgets.Dropdown( - options=["Basic", "Task Breakdowns"], - value="Task Breakdowns", - disabled=False, - ) - obj_dep = widgets.Checkbox( - value=True, disabled=False, layout=widgets.Layout(width='20px')) - task_dep = widgets.Checkbox( - value=True, disabled=False, layout=widgets.Layout(width='20px')) - # Labels to bypass width limitation for descriptions. - label_tasks = widgets.Label( - value='Task submissions', layout=widgets.Layout(width='110px')) - label_objects = widgets.Label( - value='Object dependencies', layout=widgets.Layout(width='130px')) - label_options = widgets.Label( - value='View options:', layout=widgets.Layout(width='100px')) - start_box, end_box, range_slider, time_opt = get_sliders(False) - display(widgets.HBox([task_dep, label_tasks, obj_dep, label_objects])) - display(widgets.HBox([label_options, breakdown_opt])) - display(path_input) + html_file_path = _get_temp_file_path(suffix=".html") + json_file_path = _get_temp_file_path(suffix=".json") + + ray.global_state.chrome_tracing_dump(filename=json_file_path) + + with open(trace_viewer_path) as f: + data = f.read() + + # Replace the demo data path with our own + # https://github.com/catapult-project/catapult/blob/ + # 33a9271eb3cf5caf925293ec6a4b47c94f1ac968/tracing/bin/index.html#L107 + data = data.replace("../test_data/big_trace.json", json_file_path) + + with open(html_file_path, "w+") as f: + f.write(data) + # Display the task trace within the Jupyter notebook + clear_output(wait=True) + logger.info("To view fullscreen, open chrome://tracing in Google Chrome " + "and load `{}`".format(os.path.abspath(json_file_path))) + display(IFrame(html_file_path, 900, 800)) + + +def object_transfer_timeline(): # Check that the trace viewer renderer file is present, and copy it to the # current working directory if it is not present. 
if not os.path.exists("trace_viewer_full.html"): @@ -328,76 +298,32 @@ def task_timeline(): "../core/src/catapult_files/trace_viewer_full.html"), "trace_viewer_full.html") - def handle_submit(sender): - json_tmp = tempfile.mktemp() + ".json" - - # Determine whether task components should be displayed or not. - if breakdown_opt.value == breakdown_basic: - breakdown = False - elif breakdown_opt.value == breakdown_task: - breakdown = True - else: - raise ValueError("Unexpected breakdown value '{}'".format( - breakdown_opt.value)) - - low, high = map(lambda x: x / 100., range_slider.value) - - smallest, largest, num_tasks = ray.global_state._job_length() - diff = largest - smallest - - if time_opt.value == total_time_value: - tasks = _truncated_task_profiles( - start=smallest + diff * low, end=smallest + diff * high) - elif time_opt.value == total_tasks_value: - if range_slider.value[0] == 0: - tasks = _truncated_task_profiles( - num_tasks=int(num_tasks * high), fwd=True) - else: - tasks = _truncated_task_profiles( - num_tasks=int(num_tasks * (high - low)), fwd=False) - else: - raise ValueError("Unexpected time value '{}'".format( - time_opt.value)) - # Write trace to a JSON file - print("Collected profiles for {} tasks.".format(len(tasks))) - print("Dumping task profile data to {}, " - "this might take a while...".format(json_tmp)) - ray.global_state.dump_catapult_trace( - json_tmp, - tasks, - breakdowns=breakdown, - obj_dep=obj_dep.value, - task_dep=task_dep.value) - print("Opening html file in browser...") - - trace_viewer_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "../core/src/catapult_files/index.html") - - html_file_path = _get_temp_file_path(suffix=".html") - json_file_path = _get_temp_file_path(suffix=".json") - - print("Pointing to {} named {}".format(json_tmp, json_file_path)) - shutil.copy(json_tmp, json_file_path) - - with open(trace_viewer_path) as f: - data = f.read() - - # Replace the demo data path with our own - # https://github.com/catapult-project/catapult/blob/ - # 33a9271eb3cf5caf925293ec6a4b47c94f1ac968/tracing/bin/index.html#L107 - data = data.replace("../test_data/big_trace.json", json_file_path) - - with open(html_file_path, "w+") as f: - f.write(data) - - # Display the task trace within the Jupyter notebook - clear_output(wait=True) - print("To view fullscreen, open chrome://tracing in Google Chrome " - "and load `{}`".format(json_tmp)) - display(IFrame(html_file_path, 900, 800)) - - path_input.on_click(handle_submit) + trace_viewer_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "../core/src/catapult_files/index.html") + + html_file_path = _get_temp_file_path(suffix=".html") + json_file_path = _get_temp_file_path(suffix=".json") + + ray.global_state.chrome_tracing_object_transfer_dump( + filename=json_file_path) + + with open(trace_viewer_path) as f: + data = f.read() + + # Replace the demo data path with our own + # https://github.com/catapult-project/catapult/blob/ + # 33a9271eb3cf5caf925293ec6a4b47c94f1ac968/tracing/bin/index.html#L107 + data = data.replace("../test_data/big_trace.json", json_file_path) + + with open(html_file_path, "w+") as f: + f.write(data) + + # Display the task trace within the Jupyter notebook + clear_output(wait=True) + logger.info("To view fullscreen, open chrome://tracing in Google Chrome " + "and load `{}`".format(os.path.abspath(json_file_path))) + display(IFrame(html_file_path, 900, 800)) def task_completion_time_distribution(): @@ -562,12 +488,7 @@ def cpu_usage(): 
     output_notebook(resources=CDN)
 
     # Parse the client table to determine how many CPUs are available
-    num_cpus = 0
-    client_table = ray.global_state.client_table()
-    for node_ip, client_list in client_table.items():
-        for client in client_list:
-            if "CPU" in client:
-                num_cpus += client["CPU"]
+    num_cpus = ray.global_state.cluster_resources()["CPU"]
 
     # Update the plot based on the sliders
     def plot_utilization():
diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py
new file mode 100644
index 0000000000000..72ec53651df76
--- /dev/null
+++ b/python/ray/function_manager.py
@@ -0,0 +1,495 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import hashlib
+import inspect
+import json
+import sys
+import time
+import traceback
+from collections import (
+    namedtuple,
+    defaultdict,
+)
+
+import ray
+from ray import profiling
+from ray import ray_constants
+from ray import cloudpickle as pickle
+from ray.utils import (
+    is_cython,
+    is_function_or_method,
+    is_class_method,
+    check_oversized_pickle,
+    decode,
+    format_error_message,
+    push_error_to_driver,
+)
+
+FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
+                                   ["function", "function_name", "max_calls"])
+"""FunctionExecutionInfo: A named tuple storing remote function information."""
+
+
+class FunctionActorManager(object):
+    """A class used to export/load remote functions and actors.
+
+    Attributes:
+        _worker: The associated worker that this manager is related to.
+        _functions_to_export: The remote functions to export when
+            the worker gets connected.
+        _actors_to_export: The actors to export when the worker gets
+            connected.
+        _function_execution_info: The map from driver_id to function_id
+            and execution_info.
+        _num_task_executions: The map from driver_id to the number of
+            times each function has been executed.
+    """
+
+    def __init__(self, worker):
+        self._worker = worker
+        self._functions_to_export = []
+        self._actors_to_export = []
+        # This field is a dictionary that maps a driver ID to a dictionary of
+        # functions (and information about those functions) that have been
+        # registered for that driver (this inner dictionary maps function IDs
+        # to a FunctionExecutionInfo object). This should only be used on
+        # workers that execute remote functions.
+        self._function_execution_info = defaultdict(lambda: {})
+        self._num_task_executions = defaultdict(lambda: {})
+
+    def increase_task_counter(self, driver_id, function_id):
+        self._num_task_executions[driver_id][function_id] += 1
+
+    def get_task_counter(self, driver_id, function_id):
+        return self._num_task_executions[driver_id][function_id]
+
+    def export_cached(self):
+        """Export cached remote functions.
+
+        Note: this should be called only once when the worker is connected.
+        """
+        for remote_function in self._functions_to_export:
+            self._do_export(remote_function)
+        self._functions_to_export = None
+        for info in self._actors_to_export:
+            (key, actor_class_info) = info
+            self._publish_actor_class_to_key(key, actor_class_info)
+
+    def reset_cache(self):
+        self._functions_to_export = []
+        self._actors_to_export = []
+
+    def export(self, remote_function):
+        """Export a remote function.
+
+        Args:
+            remote_function: the RemoteFunction object.
+        """
+        if self._worker.mode is None:
+            # If the worker isn't connected, cache the function
+            # and export it later.
+            self._functions_to_export.append(remote_function)
+            return
+        if self._worker.mode != ray.worker.SCRIPT_MODE:
+            # Don't need to export if the worker is not a driver.
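# [Editor's illustrative sketch -- not part of the patch; it only annotates the export()
# logic interrupted here.] Assuming the usual @ray.remote decorator flow, a function
# declared before ray.init() takes the caching branch above and is only published to
# Redis once the worker connects and export_cached() runs. Roughly:
import ray

@ray.remote                # export() is called while self._worker.mode is None,
def double(x):             # so the RemoteFunction is appended to _functions_to_export.
    return 2 * x

ray.init()                 # the worker connects; export_cached() flushes the cache.
print(ray.get(double.remote(21)))  # -> 42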
+ return + self._do_export(remote_function) + + def _do_export(self, remote_function): + """Pickle a remote function and export it to redis. + + Args: + remote_function: the RemoteFunction object. + """ + # Work around limitations of Python pickling. + function = remote_function._function + function_name_global_valid = function.__name__ in function.__globals__ + function_name_global_value = function.__globals__.get( + function.__name__) + # Allow the function to reference itself as a global variable + if not is_cython(function): + function.__globals__[function.__name__] = remote_function + try: + pickled_function = pickle.dumps(function) + finally: + # Undo our changes + if function_name_global_valid: + function.__globals__[function.__name__] = ( + function_name_global_value) + else: + del function.__globals__[function.__name__] + + check_oversized_pickle(pickled_function, + remote_function._function_name, + "remote function", self._worker) + + key = (b"RemoteFunction:" + self._worker.task_driver_id.id() + b":" + + remote_function._function_id) + self._worker.redis_client.hmset( + key, { + "driver_id": self._worker.task_driver_id.id(), + "function_id": remote_function._function_id, + "name": remote_function._function_name, + "module": function.__module__, + "function": pickled_function, + "max_calls": remote_function._max_calls + }) + self._worker.redis_client.rpush("Exports", key) + + def fetch_and_register_remote_function(self, key): + """Import a remote function.""" + (driver_id, function_id_str, function_name, serialized_function, + num_return_vals, module, resources, + max_calls) = self._worker.redis_client.hmget(key, [ + "driver_id", "function_id", "name", "function", "num_return_vals", + "module", "resources", "max_calls" + ]) + function_id = ray.ObjectID(function_id_str) + function_name = decode(function_name) + max_calls = int(max_calls) + module = decode(module) + + # This is a placeholder in case the function can't be unpickled. This + # will be overwritten if the function is successfully registered. + def f(): + raise Exception("This function was not imported properly.") + + self._function_execution_info[driver_id][function_id.id()] = ( + FunctionExecutionInfo( + function=f, function_name=function_name, max_calls=max_calls)) + self._num_task_executions[driver_id][function_id.id()] = 0 + + try: + function = pickle.loads(serialized_function) + except Exception: + # If an exception was thrown when the remote function was imported, + # we record the traceback and notify the scheduler of the failure. + traceback_str = format_error_message(traceback.format_exc()) + # Log the error message. + push_error_to_driver( + self._worker, + ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR, + traceback_str, + driver_id=driver_id, + data={ + "function_id": function_id.id(), + "function_name": function_name + }) + else: + # The below line is necessary. Because in the driver process, + # if the function is defined in the file where the python script + # was started from, its module is `__main__`. + # However in the worker process, the `__main__` module is a + # different module, which is `default_worker.py` + function.__module__ = module + self._function_execution_info[driver_id][function_id.id()] = ( + FunctionExecutionInfo( + function=function, + function_name=function_name, + max_calls=max_calls)) + # Add the function to the function table. 
+ self._worker.redis_client.rpush( + b"FunctionTable:" + function_id.id(), self._worker.worker_id) + + def get_execution_info(self, driver_id, function_id): + """Get the FunctionExecutionInfo of a remote function. + + Args: + driver_id: ID of the driver that the function belongs to. + function_id: ID of the function to get. + + Returns: + A FunctionExecutionInfo object. + """ + # Wait until the function to be executed has actually been registered + # on this worker. We will push warnings to the user if we spend too + # long in this loop. + with profiling.profile("wait_for_function", worker=self._worker): + self._wait_for_function(function_id, driver_id) + return self._function_execution_info[driver_id][function_id.id()] + + def _wait_for_function(self, function_id, driver_id, timeout=10): + """Wait until the function to be executed is present on this worker. + + This method will simply loop until the import thread has imported the + relevant function. If we spend too long in this loop, that may indicate + a problem somewhere and we will push an error message to the user. + + If this worker is an actor, then this will wait until the actor has + been defined. + + Args: + function_id (str): The ID of the function that we want to execute. + driver_id (str): The ID of the driver to push the error message to + if this times out. + """ + start_time = time.time() + # Only send the warning once. + warning_sent = False + while True: + with self._worker.lock: + if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID + and (function_id.id() in + self._function_execution_info[driver_id])): + break + elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and ( + self._worker.actor_id in self._worker.actors): + break + if time.time() - start_time > timeout: + warning_message = ("This worker was asked to execute a " + "function that it does not have " + "registered. You may have to restart " + "Ray.") + if not warning_sent: + ray.utils.push_error_to_driver( + self._worker, + ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR, + warning_message, + driver_id=driver_id) + warning_sent = True + time.sleep(0.001) + + @classmethod + def compute_actor_method_function_id(cls, class_name, attr): + """Get the function ID corresponding to an actor method. + + Args: + class_name (str): The class name of the actor. + attr (str): The attribute name of the method. + + Returns: + Function ID corresponding to the method. + """ + function_id_hash = hashlib.sha1() + function_id_hash.update(class_name.encode("ascii")) + function_id_hash.update(attr.encode("ascii")) + function_id = function_id_hash.digest() + assert len(function_id) == ray_constants.ID_SIZE + return ray.ObjectID(function_id) + + def _publish_actor_class_to_key(self, key, actor_class_info): + """Push an actor class definition to Redis. + + The is factored out as a separate function because it is also called + on cached actor class definitions when a worker connects for the first + time. + + Args: + key: The key to store the actor class info at. + actor_class_info: Information about the actor class. + worker: The worker to use to connect to Redis. + """ + # We set the driver ID here because it may not have been available when + # the actor class was defined. 
+ actor_class_info["driver_id"] = self._worker.task_driver_id.id() + self._worker.redis_client.hmset(key, actor_class_info) + self._worker.redis_client.rpush("Exports", key) + + def export_actor_class(self, class_id, Class, actor_method_names, + checkpoint_interval): + key = b"ActorClass:" + class_id + actor_class_info = { + "class_name": Class.__name__, + "module": Class.__module__, + "class": pickle.dumps(Class), + "checkpoint_interval": checkpoint_interval, + "actor_method_names": json.dumps(list(actor_method_names)) + } + + check_oversized_pickle(actor_class_info["class"], + actor_class_info["class_name"], "actor", + self._worker) + + if self._worker.mode is None: + # This means that 'ray.init()' has not been called yet and so we + # must cache the actor class definition and export it when + # 'ray.init()' is called. + assert self._actors_to_export is not None + self._actors_to_export.append((key, actor_class_info)) + # This caching code path is currently not used because we only + # export actor class definitions lazily when we instantiate the + # actor for the first time. + assert False, "This should be unreachable." + else: + self._publish_actor_class_to_key(key, actor_class_info) + # TODO(rkn): Currently we allow actor classes to be defined + # within tasks. I tried to disable this, but it may be necessary + # because of https://github.com/ray-project/ray/issues/1146. + + def fetch_and_register_actor(self, actor_class_key): + """Import an actor. + + This will be called by the worker's import thread when the worker + receives the actor_class export, assuming that the worker is an actor + for that class. + + Args: + actor_class_key: The key in Redis to use to fetch the actor. + worker: The worker to use. + """ + actor_id_str = self._worker.actor_id + (driver_id, class_id, class_name, module, pickled_class, + checkpoint_interval, + actor_method_names) = self._worker.redis_client.hmget( + actor_class_key, [ + "driver_id", "class_id", "class_name", "module", "class", + "checkpoint_interval", "actor_method_names" + ]) + + class_name = decode(class_name) + module = decode(module) + checkpoint_interval = int(checkpoint_interval) + actor_method_names = json.loads(decode(actor_method_names)) + + # In Python 2, json loads strings as unicode, so convert them back to + # strings. + if sys.version_info < (3, 0): + actor_method_names = [ + method_name.encode("ascii") + for method_name in actor_method_names + ] + + # Create a temporary actor with some temporary methods so that if + # the actor fails to be unpickled, the temporary actor can be used + # (just to produce error messages and to prevent the driver from + # hanging). + class TemporaryActor(object): + pass + + self._worker.actors[actor_id_str] = TemporaryActor() + self._worker.actor_checkpoint_interval = checkpoint_interval + + def temporary_actor_method(*xs): + raise Exception( + "The actor with name {} failed to be imported, " + "and so cannot execute this method".format(class_name)) + + # Register the actor method executors. 
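# [Editor's illustrative sketch -- not part of the patch.] The loop below keys each
# temporary executor by the ID returned from compute_actor_method_function_id(), defined
# earlier in this file: a SHA-1 over the class name and method name. A standalone
# approximation, with "Counter" / "increment" as purely hypothetical names:
import binascii
import hashlib

def sketch_actor_method_function_id(class_name, method_name):
    # SHA-1 over the class name followed by the method name; the original code asserts
    # that the 20-byte digest length equals ray_constants.ID_SIZE and uses it directly.
    h = hashlib.sha1()
    h.update(class_name.encode("ascii"))
    h.update(method_name.encode("ascii"))
    return h.digest()

# The driver and the actor's worker derive the same ID independently, so a task for a
# given actor method always resolves to the same executor entry.
print(binascii.hexlify(sketch_actor_method_function_id("Counter", "increment")))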
+ for actor_method_name in actor_method_names: + function_id = ( + FunctionActorManager.compute_actor_method_function_id( + class_name, actor_method_name).id()) + temporary_executor = self._make_actor_method_executor( + actor_method_name, + temporary_actor_method, + actor_imported=False) + self._function_execution_info[driver_id][function_id] = ( + FunctionExecutionInfo( + function=temporary_executor, + function_name=actor_method_name, + max_calls=0)) + self._num_task_executions[driver_id][function_id] = 0 + + try: + unpickled_class = pickle.loads(pickled_class) + self._worker.actor_class = unpickled_class + except Exception: + # If an exception was thrown when the actor was imported, we record + # the traceback and notify the scheduler of the failure. + traceback_str = ray.utils.format_error_message( + traceback.format_exc()) + # Log the error message. + push_error_to_driver( + self._worker, + ray_constants.REGISTER_ACTOR_PUSH_ERROR, + traceback_str, + driver_id, + data={"actor_id": actor_id_str}) + # TODO(rkn): In the future, it might make sense to have the worker + # exit here. However, currently that would lead to hanging if + # someone calls ray.get on a method invoked on the actor. + else: + # TODO(pcm): Why is the below line necessary? + unpickled_class.__module__ = module + self._worker.actors[actor_id_str] = unpickled_class.__new__( + unpickled_class) + + actor_methods = inspect.getmembers( + unpickled_class, predicate=is_function_or_method) + for actor_method_name, actor_method in actor_methods: + function_id = ( + FunctionActorManager.compute_actor_method_function_id( + class_name, actor_method_name).id()) + executor = self._make_actor_method_executor( + actor_method_name, actor_method, actor_imported=True) + self._function_execution_info[driver_id][function_id] = ( + FunctionExecutionInfo( + function=executor, + function_name=actor_method_name, + max_calls=0)) + # We do not set function_properties[driver_id][function_id] + # because we currently do need the actor worker to submit new + # tasks for the actor. + + def _make_actor_method_executor(self, method_name, method, actor_imported): + """Make an executor that wraps a user-defined actor method. + + The wrapped method updates the worker's internal state and performs any + necessary checkpointing operations. + + Args: + worker (Worker): The worker that is executing the actor. + method_name (str): The name of the actor method. + method (instancemethod): The actor method to wrap. This should be a + method defined on the actor class and should therefore take an + instance of the actor as the first argument. + actor_imported (bool): Whether the actor has been imported. + Checkpointing operations will not be run if this is set to + False. + + Returns: + A function that executes the given actor method on the worker's + stored instance of the actor. The function also updates the + worker's internal state to record the executed method. + """ + + def actor_method_executor(dummy_return_id, actor, *args): + # Update the actor's task counter to reflect the task we're about + # to execute. + self._worker.actor_task_counter += 1 + + # If this is the first task to execute on the actor, try to resume + # from a checkpoint. + if actor_imported and self._worker.actor_task_counter == 1: + checkpoint_resumed = ray.actor.restore_and_log_checkpoint( + self._worker, actor) + if checkpoint_resumed: + # NOTE(swang): Since we did not actually execute the + # __init__ method, this will put None as the return value. 
+ # If the __init__ method is supposed to return multiple + # values, an exception will be logged. + return + + # Determine whether we should checkpoint the actor. + checkpointing_on = (actor_imported + and self._worker.actor_checkpoint_interval > 0) + # We should checkpoint the actor if user checkpointing is on, we've + # executed checkpoint_interval tasks since the last checkpoint, and + # the method we're about to execute is not a checkpoint. + save_checkpoint = (checkpointing_on + and (self._worker.actor_task_counter % + self._worker.actor_checkpoint_interval == 0 + and method_name != "__ray_checkpoint__")) + + # Execute the assigned method and save a checkpoint if necessary. + try: + if is_class_method(method): + method_returns = method(*args) + else: + method_returns = method(actor, *args) + except Exception: + # Save the checkpoint before allowing the method exception + # to be thrown. + if save_checkpoint: + ray.actor.save_and_log_checkpoint(self._worker, actor) + raise + else: + # Save the checkpoint before returning the method's return + # values. + if save_checkpoint: + ray.actor.save_and_log_checkpoint(self._worker, actor) + return method_returns + + return actor_method_executor diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 2616e064d850f..347f7ab9f8064 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -4,19 +4,6 @@ import flatbuffers -from ray.core.generated.ResultTableReply import ResultTableReply -from ray.core.generated.SubscribeToNotificationsReply \ - import SubscribeToNotificationsReply -from ray.core.generated.TaskExecutionDependencies import \ - TaskExecutionDependencies -from ray.core.generated.TaskReply import TaskReply -from ray.core.generated.DriverTableMessage import DriverTableMessage -from ray.core.generated.LocalSchedulerInfoMessage import \ - LocalSchedulerInfoMessage -from ray.core.generated.SubscribeToDBClientTableReply import \ - SubscribeToDBClientTableReply -from ray.core.generated.TaskInfo import TaskInfo - import ray.core.generated.ErrorTableData from ray.core.generated.GcsTableEntry import GcsTableEntry @@ -24,6 +11,7 @@ from ray.core.generated.ErrorTableData import ErrorTableData from ray.core.generated.ProfileTableData import ProfileTableData from ray.core.generated.HeartbeatTableData import HeartbeatTableData +from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData from ray.core.generated.DriverTableData import DriverTableData from ray.core.generated.ObjectTableData import ObjectTableData from ray.core.generated.ray.protocol.Task import Task @@ -32,31 +20,17 @@ from ray.core.generated.TablePubsub import TablePubsub __all__ = [ - "SubscribeToNotificationsReply", "ResultTableReply", - "TaskExecutionDependencies", "TaskReply", "DriverTableMessage", - "LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo", "GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData", - "DriverTableData", "ProfileTableData", "ObjectTableData", "Task", - "TablePrefix", "TablePubsub", "construct_error_message" + "HeartbeatBatchTableData", "DriverTableData", "ProfileTableData", + "ObjectTableData", "Task", "TablePrefix", "TablePubsub", + "construct_error_message" ] -# These prefixes must be kept up-to-date with the definitions in -# ray_redis_module.cc. 
-DB_CLIENT_PREFIX = "CL:" -TASK_PREFIX = "TT:" -OBJECT_CHANNEL_PREFIX = "OC:" -OBJECT_INFO_PREFIX = "OI:" -OBJECT_LOCATION_PREFIX = "OL:" FUNCTION_PREFIX = "RemoteFunction:" -# These prefixes must be kept up-to-date with the definitions in -# common/state/redis.cc -LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers" -PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers" -DRIVER_DEATH_CHANNEL = b"driver_deaths" - # xray heartbeats XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii") +XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii") # xray driver updates XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii") diff --git a/python/ray/global_scheduler/__init__.py b/python/ray/global_scheduler/__init__.py deleted file mode 100644 index 25e4d2cf6490c..0000000000000 --- a/python/ray/global_scheduler/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from .global_scheduler_services import start_global_scheduler - -__all__ = ["start_global_scheduler"] diff --git a/python/ray/global_scheduler/build/.gitkeep b/python/ray/global_scheduler/build/.gitkeep deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/python/ray/global_scheduler/global_scheduler_services.py b/python/ray/global_scheduler/global_scheduler_services.py deleted file mode 100644 index 7e3d019ffa980..0000000000000 --- a/python/ray/global_scheduler/global_scheduler_services.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import subprocess -import time - - -def start_global_scheduler(redis_address, - node_ip_address, - use_valgrind=False, - use_profiler=False, - stdout_file=None, - stderr_file=None): - """Start a global scheduler process. - - Args: - redis_address (str): The address of the Redis instance. - node_ip_address: The IP address of the node that this scheduler will - run on. - use_valgrind (bool): True if the global scheduler should be started - inside of valgrind. If this is True, use_profiler must be False. - use_profiler (bool): True if the global scheduler should be started - inside a profiler. If this is True, use_valgrind must be False. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - - Return: - The process ID of the global scheduler process. 
- """ - if use_valgrind and use_profiler: - raise Exception("Cannot use valgrind and profiler at the same time.") - global_scheduler_executable = os.path.join( - os.path.abspath(os.path.dirname(__file__)), - "../core/src/global_scheduler/global_scheduler") - command = [ - global_scheduler_executable, "-r", redis_address, "-h", node_ip_address - ] - if use_valgrind: - pid = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", - "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - elif use_profiler: - pid = subprocess.Popen( - ["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - else: - pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - time.sleep(0.1) - return pid diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py deleted file mode 100644 index 37aad62ee1b01..0000000000000 --- a/python/ray/global_scheduler/test/test.py +++ /dev/null @@ -1,332 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import os -import random -import signal -import sys -import time -import unittest - -# The ray import must come before the pyarrow import because ray modifies the -# python path so that the right version of pyarrow is found. -import ray.global_scheduler as global_scheduler -import ray.local_scheduler as local_scheduler -import ray.plasma as plasma -from ray.plasma.utils import create_object -from ray import services -from ray.experimental import state -import ray.ray_constants as ray_constants -import pyarrow as pa - -USE_VALGRIND = False -PLASMA_STORE_MEMORY = 1000000000 -NUM_CLUSTER_NODES = 2 - -NIL_WORKER_ID = ray_constants.ID_SIZE * b"\xff" -NIL_OBJECT_ID = ray_constants.ID_SIZE * b"\xff" -NIL_ACTOR_ID = ray_constants.ID_SIZE * b"\xff" - - -def random_driver_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_task_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_function_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_object_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def new_port(): - return random.randint(10000, 65535) - - -class TestGlobalScheduler(unittest.TestCase): - def setUp(self): - # Start one Redis server and N pairs of (plasma, local_scheduler) - self.node_ip_address = "127.0.0.1" - redis_address, redis_shards = services.start_redis( - self.node_ip_address) - redis_port = services.get_port(redis_address) - time.sleep(0.1) - # Create a client for the global state store. - self.state = state.GlobalState() - self.state._initialize_global_state(self.node_ip_address, redis_port) - - # Start one global scheduler. - self.p1 = global_scheduler.start_global_scheduler( - redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND) - self.plasma_store_pids = [] - self.plasma_manager_pids = [] - self.local_scheduler_pids = [] - self.plasma_clients = [] - self.local_scheduler_clients = [] - - for i in range(NUM_CLUSTER_NODES): - # Start the Plasma store. Plasma store name is randomly generated. - plasma_store_name, p2 = plasma.start_plasma_store() - self.plasma_store_pids.append(p2) - # Start the Plasma manager. 
- # Assumption: Plasma manager name and port are randomly generated - # by the plasma module. - manager_info = plasma.start_plasma_manager(plasma_store_name, - redis_address) - plasma_manager_name, p3, plasma_manager_port = manager_info - self.plasma_manager_pids.append(p3) - plasma_address = "{}:{}".format(self.node_ip_address, - plasma_manager_port) - plasma_client = pa.plasma.connect(plasma_store_name, - plasma_manager_name, 64) - self.plasma_clients.append(plasma_client) - # Start the local scheduler. - local_scheduler_name, p4 = local_scheduler.start_local_scheduler( - plasma_store_name, - plasma_manager_name=plasma_manager_name, - plasma_address=plasma_address, - redis_address=redis_address, - static_resources={"CPU": 10}) - # Connect to the scheduler. - local_scheduler_client = local_scheduler.LocalSchedulerClient( - local_scheduler_name, NIL_WORKER_ID, False, random_task_id(), - False) - self.local_scheduler_clients.append(local_scheduler_client) - self.local_scheduler_pids.append(p4) - - def tearDown(self): - # Check that the processes are still alive. - self.assertEqual(self.p1.poll(), None) - for p2 in self.plasma_store_pids: - self.assertEqual(p2.poll(), None) - for p3 in self.plasma_manager_pids: - self.assertEqual(p3.poll(), None) - for p4 in self.local_scheduler_pids: - self.assertEqual(p4.poll(), None) - - redis_processes = services.all_processes[ - services.PROCESS_TYPE_REDIS_SERVER] - for redis_process in redis_processes: - self.assertEqual(redis_process.poll(), None) - - # Kill the global scheduler. - if USE_VALGRIND: - self.p1.send_signal(signal.SIGTERM) - self.p1.wait() - if self.p1.returncode != 0: - os._exit(-1) - else: - self.p1.kill() - # Kill local schedulers, plasma managers, and plasma stores. - for p2 in self.local_scheduler_pids: - p2.kill() - for p3 in self.plasma_manager_pids: - p3.kill() - for p4 in self.plasma_store_pids: - p4.kill() - # Kill Redis. In the event that we are using valgrind, this needs to - # happen after we kill the global scheduler. - while redis_processes: - redis_process = redis_processes.pop() - redis_process.kill() - - def get_plasma_manager_id(self): - """Get the db_client_id with client_type equal to plasma_manager. - - Iterates over all the client table keys, gets the db_client_id for the - client with client_type matching plasma_manager. Strips the client - table prefix. TODO(atumanov): write a separate function to get all - plasma manager client IDs. - - Returns: - The db_client_id if one is found and otherwise None. - """ - db_client_id = None - - client_list = self.state.client_table()[self.node_ip_address] - for client in client_list: - if client["ClientType"] == "plasma_manager": - db_client_id = client["DBClientID"] - break - - return db_client_id - - def test_task_default_resources(self): - task1 = local_scheduler.Task( - random_driver_id(), random_function_id(), [random_object_id()], 0, - random_task_id(), 0) - self.assertEqual(task1.required_resources(), {"CPU": 1}) - task2 = local_scheduler.Task( - random_driver_id(), random_function_id(), [random_object_id()], 0, - random_task_id(), 0, local_scheduler.ObjectID(NIL_ACTOR_ID), - local_scheduler.ObjectID(NIL_OBJECT_ID), - local_scheduler.ObjectID(NIL_ACTOR_ID), - local_scheduler.ObjectID(NIL_ACTOR_ID), 0, 0, [], { - "CPU": 1, - "GPU": 2 - }) - self.assertEqual(task2.required_resources(), {"CPU": 1, "GPU": 2}) - - def test_redis_only_single_task(self): - # Tests global scheduler functionality by interacting with Redis and - # checking task state transitions in Redis only. 
TODO(atumanov): - # implement. - - # Check precondition for this test: - # There should be 2n+1 db clients: the global scheduler + one local - # scheduler and one plasma per node. - self.assertEqual( - len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) - db_client_id = self.get_plasma_manager_id() - assert (db_client_id is not None) - - @unittest.skipIf( - os.environ.get("RAY_USE_NEW_GCS", False), - "New GCS API doesn't have a Python API yet.") - def test_integration_single_task(self): - # There should be three db clients, the global scheduler, the local - # scheduler, and the plasma manager. - self.assertEqual( - len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) - - num_return_vals = [0, 1, 2, 3, 5, 10] - # Insert the object into Redis. - data_size = 0xf1f0 - metadata_size = 0x40 - plasma_client = self.plasma_clients[0] - object_dep, memory_buffer, metadata = create_object( - plasma_client, data_size, metadata_size, seal=True) - - # Sleep before submitting task to local scheduler. - time.sleep(0.1) - # Submit a task to Redis. - task = local_scheduler.Task( - random_driver_id(), random_function_id(), - [local_scheduler.ObjectID(object_dep.binary())], - num_return_vals[0], random_task_id(), 0) - self.local_scheduler_clients[0].submit(task) - time.sleep(0.1) - # There should now be a task in Redis, and it should get assigned to - # the local scheduler - num_retries = 10 - while num_retries > 0: - task_entries = self.state.task_table() - self.assertLessEqual(len(task_entries), 1) - if len(task_entries) == 1: - task_id, task = task_entries.popitem() - task_status = task["State"] - self.assertTrue(task_status in [ - state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED, - state.TASK_STATUS_QUEUED - ]) - if task_status == state.TASK_STATUS_QUEUED: - break - else: - print(task_status) - print("The task has not been scheduled yet, trying again.") - num_retries -= 1 - time.sleep(1) - - if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED: - # Failed to submit and schedule a single task -- bail. - self.tearDown() - sys.exit(1) - - def integration_many_tasks_helper(self, timesync=True): - # There should be three db clients, the global scheduler, the local - # scheduler, and the plasma manager. - self.assertEqual( - len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) - num_return_vals = [0, 1, 2, 3, 5, 10] - - # Submit a bunch of tasks to Redis. - num_tasks = 1000 - for _ in range(num_tasks): - # Create a new object for each task. - data_size = np.random.randint(1 << 12) - metadata_size = np.random.randint(1 << 9) - plasma_client = self.plasma_clients[0] - object_dep, memory_buffer, metadata = create_object( - plasma_client, data_size, metadata_size, seal=True) - if timesync: - # Give 10ms for object info handler to fire (long enough to - # yield CPU). - time.sleep(0.010) - task = local_scheduler.Task( - random_driver_id(), random_function_id(), - [local_scheduler.ObjectID(object_dep.binary())], - num_return_vals[0], random_task_id(), 0) - self.local_scheduler_clients[0].submit(task) - # Check that there are the correct number of tasks in Redis and that - # they all get assigned to the local scheduler. - num_retries = 20 - num_tasks_done = 0 - while num_retries > 0: - task_entries = self.state.task_table() - self.assertLessEqual(len(task_entries), num_tasks) - # First, check if all tasks made it to Redis. 
- if len(task_entries) == num_tasks: - task_statuses = [ - task_entry["State"] - for task_entry in task_entries.values() - ] - self.assertTrue( - all(status in [ - state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED, - state.TASK_STATUS_QUEUED - ] for status in task_statuses)) - num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED) - num_tasks_scheduled = task_statuses.count( - state.TASK_STATUS_SCHEDULED) - num_tasks_waiting = task_statuses.count( - state.TASK_STATUS_WAITING) - print("tasks in Redis = {}, tasks waiting = {}, " - "tasks scheduled = {}, " - "tasks queued = {}, retries left = {}".format( - len(task_entries), num_tasks_waiting, - num_tasks_scheduled, num_tasks_done, num_retries)) - if all(status == state.TASK_STATUS_QUEUED - for status in task_statuses): - # We're done, so pass. - break - num_retries -= 1 - time.sleep(0.1) - - # Tasks can either be queued or in the global scheduler due to - # spillback. - self.assertEqual(num_tasks_done + num_tasks_waiting, num_tasks) - - @unittest.skipIf( - os.environ.get("RAY_USE_NEW_GCS", False), - "New GCS API doesn't have a Python API yet.") - def test_integration_many_tasks_handler_sync(self): - self.integration_many_tasks_helper(timesync=True) - - @unittest.skipIf( - os.environ.get("RAY_USE_NEW_GCS", False), - "New GCS API doesn't have a Python API yet.") - def test_integration_many_tasks(self): - # More realistic case: should handle out of order object and task - # notifications. - self.integration_many_tasks_helper(timesync=False) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - # Pop the argument so we don't mess with unittest's own argument - # parser. - if sys.argv[-1] == "valgrind": - arg = sys.argv.pop() - USE_VALGRIND = True - print("Using valgrind for tests") - unittest.main(verbosity=2) diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py index 659cdf1ce281e..70dba322370bb 100644 --- a/python/ray/import_thread.py +++ b/python/ray/import_thread.py @@ -88,7 +88,8 @@ def _process_key(self, key): if key.startswith(b"RemoteFunction"): with profiling.profile( "register_remote_function", worker=self.worker): - self.fetch_and_register_remote_function(key) + (self.worker.function_actor_manager. + fetch_and_register_remote_function(key)) elif key.startswith(b"FunctionsToRun"): with profiling.profile( "fetch_and_run_function", worker=self.worker): @@ -103,65 +104,13 @@ def _process_key(self, key): else: raise Exception("This code should be unreachable.") - def fetch_and_register_remote_function(self, key): - """Import a remote function.""" - from ray.worker import FunctionExecutionInfo - (driver_id, function_id_str, function_name, serialized_function, - num_return_vals, module, resources, - max_calls) = self.redis_client.hmget(key, [ - "driver_id", "function_id", "name", "function", "num_return_vals", - "module", "resources", "max_calls" - ]) - function_id = ray.ObjectID(function_id_str) - function_name = utils.decode(function_name) - max_calls = int(max_calls) - module = utils.decode(module) - - # This is a placeholder in case the function can't be unpickled. This - # will be overwritten if the function is successfully registered. 
- def f(): - raise Exception("This function was not imported properly.") - - self.worker.function_execution_info[driver_id][function_id.id()] = ( - FunctionExecutionInfo( - function=f, function_name=function_name, max_calls=max_calls)) - self.worker.num_task_executions[driver_id][function_id.id()] = 0 - - try: - function = pickle.loads(serialized_function) - except Exception: - # If an exception was thrown when the remote function was imported, - # we record the traceback and notify the scheduler of the failure. - traceback_str = utils.format_error_message(traceback.format_exc()) - # Log the error message. - utils.push_error_to_driver( - self.worker, - ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR, - traceback_str, - driver_id=driver_id, - data={ - "function_id": function_id.id(), - "function_name": function_name - }) - else: - # TODO(rkn): Why is the below line necessary? - function.__module__ = module - self.worker.function_execution_info[driver_id][ - function_id.id()] = (FunctionExecutionInfo( - function=function, - function_name=function_name, - max_calls=max_calls)) - # Add the function to the function table. - self.redis_client.rpush(b"FunctionTable:" + function_id.id(), - self.worker.worker_id) - def fetch_and_execute_function_to_run(self, key): """Run on arbitrary function on the worker.""" (driver_id, serialized_function, run_on_other_drivers) = self.redis_client.hmget( key, ["driver_id", "function", "run_on_other_drivers"]) - if (run_on_other_drivers == "False" + if (utils.decode(run_on_other_drivers) == "False" and self.worker.mode == ray.SCRIPT_MODE and driver_id != self.worker.task_driver_id.id()): return diff --git a/python/ray/internal/internal_api.py b/python/ray/internal/internal_api.py index 062d633ee44bd..7772974319aea 100644 --- a/python/ray/internal/internal_api.py +++ b/python/ray/internal/internal_api.py @@ -2,7 +2,7 @@ from __future__ import division from __future__ import print_function -import ray.local_scheduler +import ray.raylet import ray.worker from ray import profiling @@ -42,7 +42,4 @@ def free(object_ids, local_only=False, worker=None): if len(object_ids) == 0: return - if worker.use_raylet: - worker.local_scheduler_client.free(object_ids, local_only) - else: - raise Exception("Free is not supported in legacy backend.") + worker.local_scheduler_client.free(object_ids, local_only) diff --git a/python/ray/local_scheduler/build/.gitkeep b/python/ray/local_scheduler/build/.gitkeep deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/python/ray/local_scheduler/local_scheduler_services.py b/python/ray/local_scheduler/local_scheduler_services.py deleted file mode 100644 index f7847ce551b0f..0000000000000 --- a/python/ray/local_scheduler/local_scheduler_services.py +++ /dev/null @@ -1,133 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import multiprocessing -import os -import random -import subprocess -import sys -import time - - -def random_name(): - return str(random.randint(0, 99999999)) - - -def start_local_scheduler(plasma_store_name, - plasma_manager_name=None, - worker_path=None, - plasma_address=None, - node_ip_address="127.0.0.1", - redis_address=None, - use_valgrind=False, - use_profiler=False, - stdout_file=None, - stderr_file=None, - static_resources=None, - num_workers=0): - """Start a local scheduler process. - - Args: - plasma_store_name (str): The name of the plasma store socket to connect - to. 
- plasma_manager_name (str): The name of the plasma manager to connect - to. This does not need to be provided, but if it is, then the Redis - address must be provided as well. - worker_path (str): The path of the worker script to use when the local - scheduler starts up new workers. - plasma_address (str): The address of the plasma manager to connect to. - This is only used by the global scheduler to figure out which - plasma managers are connected to which local schedulers. - node_ip_address (str): The address of the node that this local - scheduler is running on. - redis_address (str): The address of the Redis instance to connect to. - If this is not provided, then the local scheduler will not connect - to Redis. - use_valgrind (bool): True if the local scheduler should be started - inside of valgrind. If this is True, use_profiler must be False. - use_profiler (bool): True if the local scheduler should be started - inside a profiler. If this is True, use_valgrind must be False. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - static_resources: A dictionary specifying the local scheduler's - resource capacities. This maps resource names (strings) to - integers or floats. - num_workers (int): The number of workers that the local scheduler - should start. - - Return: - A tuple of the name of the local scheduler socket and the process ID of - the local scheduler process. - """ - if (plasma_manager_name is None) != (redis_address is None): - raise Exception("If one of the plasma_manager_name and the " - "redis_address is provided, then both must be " - "provided.") - if use_valgrind and use_profiler: - raise Exception("Cannot use valgrind and profiler at the same time.") - local_scheduler_executable = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "../core/src/local_scheduler/local_scheduler") - local_scheduler_name = "/tmp/scheduler{}".format(random_name()) - command = [ - local_scheduler_executable, "-s", local_scheduler_name, "-p", - plasma_store_name, "-h", node_ip_address, "-n", - str(num_workers) - ] - if plasma_manager_name is not None: - command += ["-m", plasma_manager_name] - if worker_path is not None: - assert plasma_store_name is not None - assert plasma_manager_name is not None - assert redis_address is not None - start_worker_command = ("{} {} " - "--node-ip-address={} " - "--object-store-name={} " - "--object-store-manager-name={} " - "--local-scheduler-name={} " - "--redis-address={}".format( - sys.executable, worker_path, - node_ip_address, plasma_store_name, - plasma_manager_name, local_scheduler_name, - redis_address)) - command += ["-w", start_worker_command] - if redis_address is not None: - command += ["-r", redis_address] - if plasma_address is not None: - command += ["-a", plasma_address] - if static_resources is not None: - resource_argument = "" - for resource_name, resource_quantity in static_resources.items(): - assert (isinstance(resource_quantity, int) - or isinstance(resource_quantity, float)) - resource_argument = ",".join([ - resource_name + "," + str(resource_quantity) - for resource_name, resource_quantity in static_resources.items() - ]) - else: - resource_argument = "CPU,{}".format(multiprocessing.cpu_count()) - command += ["-c", resource_argument] - - if use_valgrind: - pid = subprocess.Popen( - [ - "valgrind", 
"--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", - "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - elif use_profiler: - pid = subprocess.Popen( - ["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - else: - pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - time.sleep(0.1) - return local_scheduler_name, pid diff --git a/python/ray/local_scheduler/test/test.py b/python/ray/local_scheduler/test/test.py deleted file mode 100644 index b35d609de6e0e..0000000000000 --- a/python/ray/local_scheduler/test/test.py +++ /dev/null @@ -1,206 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import os -import signal -import sys -import threading -import time -import unittest - -import ray.local_scheduler as local_scheduler -import ray.plasma as plasma -import ray.ray_constants as ray_constants -import pyarrow as pa - -USE_VALGRIND = False - -NIL_WORKER_ID = ray_constants.ID_SIZE * b"\xff" - - -def random_object_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_driver_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_task_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_function_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -class TestLocalSchedulerClient(unittest.TestCase): - def setUp(self): - # Start Plasma store. - plasma_store_name, self.p1 = plasma.start_plasma_store() - self.plasma_client = pa.plasma.connect(plasma_store_name, "", 0) - # Start a local scheduler. - scheduler_name, self.p2 = local_scheduler.start_local_scheduler( - plasma_store_name, use_valgrind=USE_VALGRIND) - # Connect to the scheduler. - self.local_scheduler_client = local_scheduler.LocalSchedulerClient( - scheduler_name, NIL_WORKER_ID, False, random_task_id(), False) - - def tearDown(self): - # Check that the processes are still alive. - self.assertEqual(self.p1.poll(), None) - self.assertEqual(self.p2.poll(), None) - - # Kill Plasma. - self.p1.kill() - # Kill the local scheduler. - if USE_VALGRIND: - self.p2.send_signal(signal.SIGTERM) - self.p2.wait() - if self.p2.returncode != 0: - os._exit(-1) - else: - self.p2.kill() - - def test_submit_and_get_task(self): - function_id = random_function_id() - object_ids = [random_object_id() for i in range(256)] - # Create and seal the objects in the object store so that we can - # schedule all of the subsequent tasks. - for object_id in object_ids: - self.plasma_client.create(pa.plasma.ObjectID(object_id.id()), 0) - self.plasma_client.seal(pa.plasma.ObjectID(object_id.id())) - # Define some arguments to use for the tasks. 
- args_list = [[], [{}], [()], 1 * [1], 10 * [1], 100 * [1], 1000 * [1], - 1 * ["a"], 10 * ["a"], 100 * ["a"], 1000 * ["a"], [ - 1, 1.3, 1 << 100, "hi", u"hi", [1, 2] - ], object_ids[:1], object_ids[:2], object_ids[:3], - object_ids[:4], object_ids[:5], object_ids[:10], - object_ids[:100], object_ids[:256], [1, object_ids[0]], [ - object_ids[0], "a" - ], [1, object_ids[0], "a"], [ - object_ids[0], 1, object_ids[1], "a" - ], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], - object_ids + 100 * ["a"] + object_ids] - - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(random_driver_id(), function_id, - args, num_return_vals, - random_task_id(), 0) - # Submit a task. - self.local_scheduler_client.submit(task) - # Get the task. - new_task = self.local_scheduler_client.get_task() - self.assertEqual(task.function_id().id(), - new_task.function_id().id()) - retrieved_args = new_task.arguments() - returns = new_task.returns() - self.assertEqual(len(args), len(retrieved_args)) - self.assertEqual(num_return_vals, len(returns)) - for i in range(len(retrieved_args)): - if isinstance(args[i], local_scheduler.ObjectID): - self.assertEqual(args[i].id(), retrieved_args[i].id()) - else: - self.assertEqual(args[i], retrieved_args[i]) - - # Submit all of the tasks. - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(random_driver_id(), function_id, - args, num_return_vals, - random_task_id(), 0) - self.local_scheduler_client.submit(task) - # Get all of the tasks. - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - new_task = self.local_scheduler_client.get_task() - - def test_scheduling_when_objects_ready(self): - # Create a task and submit it. - object_id = random_object_id() - task = local_scheduler.Task(random_driver_id(), random_function_id(), - [object_id], 0, random_task_id(), 0) - self.local_scheduler_client.submit(task) - - # Launch a thread to get the task. - def get_task(): - self.local_scheduler_client.get_task() - - t = threading.Thread(target=get_task) - t.start() - # Sleep to give the thread time to call get_task. - time.sleep(0.1) - # Create and seal the object ID in the object store. This should - # trigger a scheduling event. - self.plasma_client.create(pa.plasma.ObjectID(object_id.id()), 0) - self.plasma_client.seal(pa.plasma.ObjectID(object_id.id())) - # Wait until the thread finishes so that we know the task was - # scheduled. - t.join() - - def test_scheduling_when_objects_evicted(self): - # Create a task with two dependencies and submit it. - object_id1 = random_object_id() - object_id2 = random_object_id() - task = local_scheduler.Task(random_driver_id(), random_function_id(), - [object_id1, object_id2], 0, - random_task_id(), 0) - self.local_scheduler_client.submit(task) - - # Launch a thread to get the task. - def get_task(): - self.local_scheduler_client.get_task() - - t = threading.Thread(target=get_task) - t.start() - - # Make one of the dependencies available. - buf = self.plasma_client.create(pa.plasma.ObjectID(object_id1.id()), 1) - self.plasma_client.seal(pa.plasma.ObjectID(object_id1.id())) - # Release the object. - del buf - # Check that the thread is still waiting for a task. - time.sleep(0.1) - self.assertTrue(t.is_alive()) - # Force eviction of the first dependency. - self.plasma_client.evict(plasma.DEFAULT_PLASMA_STORE_MEMORY) - # Check that the thread is still waiting for a task. 
- time.sleep(0.1) - self.assertTrue(t.is_alive()) - # Check that the first object dependency was evicted. - object1 = self.plasma_client.get_buffers( - [pa.plasma.ObjectID(object_id1.id())], timeout_ms=0) - self.assertEqual(object1, [None]) - # Check that the thread is still waiting for a task. - time.sleep(0.1) - self.assertTrue(t.is_alive()) - - # Create the second dependency. - self.plasma_client.create(pa.plasma.ObjectID(object_id2.id()), 1) - self.plasma_client.seal(pa.plasma.ObjectID(object_id2.id())) - # Check that the thread is still waiting for a task. - time.sleep(0.1) - self.assertTrue(t.is_alive()) - - # Create the first dependency again. Both dependencies are now - # available. - self.plasma_client.create(pa.plasma.ObjectID(object_id1.id()), 1) - self.plasma_client.seal(pa.plasma.ObjectID(object_id1.id())) - - # Wait until the thread finishes so that we know the task was - # scheduled. - t.join() - - -if __name__ == "__main__": - if len(sys.argv) > 1: - # Pop the argument so we don't mess with unittest's own argument - # parser. - if sys.argv[-1] == "valgrind": - arg = sys.argv.pop() - USE_VALGRIND = True - print("Using valgrind for tests") - unittest.main(verbosity=2) diff --git a/python/ray/log_monitor.py b/python/ray/log_monitor.py index 13a62a98a322b..2cd6fc40a0f56 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -35,11 +35,15 @@ class LogMonitor(object): handle for that file. """ - def __init__(self, redis_ip_address, redis_port, node_ip_address): + def __init__(self, + redis_ip_address, + redis_port, + node_ip_address, + redis_password=None): """Initialize the log monitor object.""" self.node_ip_address = node_ip_address self.redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=redis_password) self.log_files = {} self.log_file_handles = {} self.files_to_ignore = set() @@ -130,6 +134,12 @@ def run(self): required=True, type=str, help="The IP address of the node this process is on.") + parser.add_argument( + "--redis-password", + required=False, + type=str, + default=None, + help="the password to use for Redis") parser.add_argument( "--logging-level", required=False, @@ -151,6 +161,9 @@ def run(self): redis_ip_address = get_ip_address(args.redis_address) redis_port = get_port(args.redis_address) - log_monitor = LogMonitor(redis_ip_address, redis_port, - args.node_ip_address) + log_monitor = LogMonitor( + redis_ip_address, + redis_port, + args.node_ip_address, + redis_password=args.redis_password) log_monitor.run() diff --git a/python/ray/memory_monitor.py b/python/ray/memory_monitor.py new file mode 100644 index 0000000000000..a52f98d7077df --- /dev/null +++ b/python/ray/memory_monitor.py @@ -0,0 +1,81 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import os +import time + +try: + import psutil +except ImportError: + psutil = None + +logger = logging.getLogger(__name__) + + +class RayOutOfMemoryError(Exception): + def __init__(self, msg): + Exception.__init__(self, msg) + + @staticmethod + def get_message(used_gb, total_gb, threshold): + pids = psutil.pids() + proc_stats = [] + for pid in pids: + proc = psutil.Process(pid) + proc_stats.append((proc.memory_info().rss, pid, proc.cmdline())) + proc_str = "PID\tMEM\tCOMMAND" + for rss, pid, cmdline in sorted(proc_stats, reverse=True)[:5]: + proc_str += "\n{}\t{}GB\t{}".format( + pid, round(rss / 1e9, 2), " ".join(cmdline)[:100].strip()) + 
return ("More than {}% of the memory on ".format(int( + 100 * threshold)) + "node {} is used ({} / {} GB). ".format( + os.uname()[1], round(used_gb, 2), round(total_gb, 2)) + + "The top 5 memory consumers are:\n\n{}".format(proc_str) + + "\n\nIn addition, ~{} GB of shared memory is ".format( + round(psutil.virtual_memory().shared / 1e9, 2)) + + "currently being used by the Ray object store. You can set " + "the object store size with the `object_store_memory` " + "parameter when starting Ray, and the max Redis size with " + "`redis_max_memory`.") + + +class MemoryMonitor(object): + """Helper class for raising errors on low memory. + + This presents a much cleaner error message to users than what would happen + if we actually ran out of memory. + """ + + def __init__(self, error_threshold=0.95, check_interval=1): + # Note: it takes ~50us to check the memory usage through psutil, so + # throttle this check at most once a second or so. + self.check_interval = check_interval + self.last_checked = time.time() + self.error_threshold = error_threshold + if not psutil: + logger.warning( + "WARNING: Not monitoring node memory since `psutil` is not " + "installed. Install this with `pip install psutil` " + "(or ray[debug]) to enable debugging of memory-related " + "crashes.") + + def raise_if_low_memory(self): + if not psutil: + return # nothing we can do + + if "RAY_DEBUG_DISABLE_MEMORY_MONITOR" in os.environ: + return # escape hatch, not intended for user use + + if time.time() - self.last_checked > self.check_interval: + self.last_checked = time.time() + total_gb = psutil.virtual_memory().total / 1e9 + used_gb = total_gb - psutil.virtual_memory().available / 1e9 + if used_gb > total_gb * self.error_threshold: + raise RayOutOfMemoryError( + RayOutOfMemoryError.get_message(used_gb, total_gb, + self.error_threshold)) + else: + logger.debug("Memory usage is {} / {}".format( + used_gb, total_gb)) diff --git a/python/ray/monitor.py b/python/ray/monitor.py index e5c2279b72333..a37f75de7cf1b 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -3,11 +3,9 @@ from __future__ import print_function import argparse -import binascii import logging import os import time -from collections import Counter, defaultdict import traceback import redis @@ -20,27 +18,6 @@ import ray.ray_constants as ray_constants from ray.services import get_ip_address, get_port from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary -from ray.worker import NIL_ACTOR_ID - -# These variables must be kept in sync with the C codebase. -# common/common.h -NIL_ID = b"\xff" * ray_constants.ID_SIZE - -# common/task.h -TASK_STATUS_LOST = 32 - -# common/redis_module/ray_redis_module.cc -OBJECT_INFO_PREFIX = b"OI:" -OBJECT_LOCATION_PREFIX = b"OL:" -TASK_TABLE_PREFIX = b"TT:" -DB_CLIENT_PREFIX = b"CL:" -DB_CLIENT_TABLE_NAME = b"db_clients" - -# local_scheduler/local_scheduler.h -LOCAL_SCHEDULER_CLIENT_TYPE = b"local_scheduler" - -# plasma/plasma_manager.cc -PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager" # Set up logging. logger = logging.getLogger(__name__) @@ -55,45 +32,24 @@ class Monitor(object): Attributes: redis: A connection to the Redis server. - use_raylet: A bool indicating whether to use the raylet code path or - not. subscribe_client: A pubsub client for the Redis server. This is used to receive notifications about failed components. - dead_local_schedulers: A set of the local scheduler IDs of all of the - local schedulers that were up at one point and have died since - then. 
- live_plasma_managers: A counter mapping live plasma manager IDs to the - number of heartbeats that have passed since we last heard from that - plasma manager. A plasma manager is live if we received a heartbeat - from it at any point, and if it has not timed out. - dead_plasma_managers: A set of the plasma manager IDs of all the plasma - managers that were up at one point and have died since then. """ - def __init__(self, redis_address, redis_port, autoscaling_config): + def __init__(self, + redis_address, + redis_port, + autoscaling_config, + redis_password=None): # Initialize the Redis clients. self.state = ray.experimental.state.GlobalState() - self.state._initialize_global_state(redis_address, redis_port) - self.use_raylet = self.state.use_raylet + self.state._initialize_global_state( + redis_address, redis_port, redis_password=redis_password) self.redis = redis.StrictRedis( - host=redis_address, port=redis_port, db=0) + host=redis_address, port=redis_port, db=0, password=redis_password) # Setup subscriptions to the primary Redis server and the Redis shards. self.primary_subscribe_client = self.redis.pubsub( ignore_subscribe_messages=True) - if self.use_raylet: - self.shard_subscribe_clients = [] - for redis_client in self.state.redis_clients: - subscribe_client = redis_client.pubsub( - ignore_subscribe_messages=True) - self.shard_subscribe_clients.append(subscribe_client) - else: - # We don't need to subscribe to the shards in legacy Ray. - self.shard_subscribe_clients = [] - # Initialize data structures to keep track of the active database - # clients. - self.dead_local_schedulers = set() - self.live_plasma_managers = Counter() - self.dead_plasma_managers = set() # Keep a mapping from local scheduler client ID to IP address to use # for updating the load metrics. self.local_scheduler_id_to_ip_map = {} @@ -118,7 +74,9 @@ def __init__(self, redis_address, redis_port, autoscaling_config): else: addr_port = addr_port[0].split(b":") self.redis_shard = redis.StrictRedis( - host=addr_port[0], port=addr_port[1]) + host=addr_port[0], + port=addr_port[1], + password=redis_password) try: self.redis_shard.execute_command("HEAD.FLUSH 0") except redis.exceptions.ResponseError as e: @@ -127,367 +85,50 @@ def __init__(self, redis_address, redis_port, autoscaling_config): str(e))) self.issue_gcs_flushes = False - def subscribe(self, channel, primary=True): - """Subscribe to the given channel. + def subscribe(self, channel): + """Subscribe to the given channel on the primary Redis shard. Args: channel (str): The channel to subscribe to. - primary: If True, then we only subscribe to the primary Redis - shard. Otherwise we subscribe to all of the other shards but - not the primary. Raises: Exception: An exception is raised if the subscription fails. """ - if primary: - self.primary_subscribe_client.subscribe(channel) - else: - for subscribe_client in self.shard_subscribe_clients: - subscribe_client.subscribe(channel) - - def cleanup_task_table(self): - """Clean up global state for failed local schedulers. + self.primary_subscribe_client.subscribe(channel) - This marks any tasks that were scheduled on dead local schedulers as - TASK_STATUS_LOST. A local scheduler is deemed dead if it is in - self.dead_local_schedulers. - """ - tasks = self.state.task_table() - num_tasks_updated = 0 - for task_id, task in tasks.items(): - # See if the corresponding local scheduler is alive. 
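
# Illustrative sketch of the updated Monitor constructor above, which now
# threads an optional Redis password through to GlobalState and the primary
# Redis client. Address, port, and password are placeholder values for a
# hypothetical deployment with a reachable Redis server.
from ray.monitor import Monitor

monitor = Monitor(
    "127.0.0.1",                        # redis_address of the primary Redis
    6379,                               # redis_port
    autoscaling_config=None,            # no autoscaler config in this sketch
    redis_password="example-password",  # placeholder; None if unauthenticated
)
monitor.run()
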
- if task["LocalSchedulerID"] not in self.dead_local_schedulers: - continue - - # Remove dummy objects returned by actor tasks from any plasma - # manager. Although the objects may still exist in that object - # store, this deletion makes them effectively unreachable by any - # local scheduler connected to a different store. - # TODO(swang): Actually remove the objects from the object store, - # so that the reconstructed actor can reuse the same object store. - if hex_to_binary(task["TaskSpec"]["ActorID"]) != NIL_ACTOR_ID: - dummy_object_id = task["TaskSpec"]["ReturnObjectIDs"][-1] - obj = self.state.object_table(dummy_object_id) - manager_ids = obj["ManagerIDs"] - if manager_ids is not None: - # The dummy object should exist on at most one plasma - # manager, the manager associated with the local scheduler - # that died. - assert len(manager_ids) <= 1 - # Remove the dummy object from the plasma manager - # associated with the dead local scheduler, if any. - for manager in manager_ids: - ok = self.state._execute_command( - dummy_object_id, "RAY.OBJECT_TABLE_REMOVE", - dummy_object_id.id(), hex_to_binary(manager)) - if ok != b"OK": - logger.warn("Failed to remove object location for " - "dead plasma manager.") - - # If the task is scheduled on a dead local scheduler, mark the - # task as lost. - key = binary_to_object_id(hex_to_binary(task_id)) - ok = self.state._execute_command( - key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id), - ray.experimental.state.TASK_STATUS_LOST, NIL_ID, - task["ExecutionDependenciesString"], task["SpillbackCount"]) - if ok != b"OK": - logger.warn("Failed to update lost task for dead scheduler.") - num_tasks_updated += 1 - - if num_tasks_updated > 0: - logger.warn("Marked {} tasks as lost.".format(num_tasks_updated)) - - def cleanup_object_table(self): - """Clean up global state for failed plasma managers. - - This removes dead plasma managers from any location entries in the - object table. A plasma manager is deemed dead if it is in - self.dead_plasma_managers. - """ - # TODO(swang): Also kill the associated plasma store, since it's no - # longer reachable without a plasma manager. - objects = self.state.object_table() - num_objects_removed = 0 - for object_id, obj in objects.items(): - manager_ids = obj["ManagerIDs"] - if manager_ids is None: - continue - for manager in manager_ids: - if manager in self.dead_plasma_managers: - # If the object was on a dead plasma manager, remove that - # location entry. - ok = self.state._execute_command( - object_id, "RAY.OBJECT_TABLE_REMOVE", object_id.id(), - hex_to_binary(manager)) - if ok != b"OK": - logger.warn("Failed to remove object location for " - "dead plasma manager.") - num_objects_removed += 1 - if num_objects_removed > 0: - logger.warn("Marked {} objects as lost." - .format(num_objects_removed)) - - def scan_db_client_table(self): - """Scan the database client table for dead clients. - - After subscribing to the client table, it's necessary to call this - before reading any messages from the subscription channel. This ensures - that we do not miss any notifications for deleted clients that occurred - before we subscribed. - """ - # Exit if we are using the raylet code path because client_table is - # implemented differently. TODO(rkn): Fix this. 
- if self.use_raylet: - return - - clients = self.state.client_table() - for node_ip_address, node_clients in clients.items(): - for client in node_clients: - db_client_id = client["DBClientID"] - client_type = client["ClientType"] - if client["Deleted"]: - if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: - self.dead_local_schedulers.add(db_client_id) - elif client_type == PLASMA_MANAGER_CLIENT_TYPE: - self.dead_plasma_managers.add(db_client_id) - - def db_client_notification_handler(self, unused_channel, data): - """Handle a notification from the db_client table from Redis. - - This handler processes notifications from the db_client table. - Notifications should be parsed using the SubscribeToDBClientTableReply - flatbuffer. Deletions are processed, insertions are ignored. Cleanup of - the associated state in the state tables should be handled by the - caller. - """ - notification_object = (ray.gcs_utils.SubscribeToDBClientTableReply. - GetRootAsSubscribeToDBClientTableReply(data, 0)) - db_client_id = binary_to_hex(notification_object.DbClientId()) - client_type = notification_object.ClientType() - is_insertion = notification_object.IsInsertion() - - # If the update was an insertion, we ignore it. - if is_insertion: - return - - # If the update was a deletion, add them to our accounting for dead - # local schedulers and plasma managers. - logger.warn("Removed {}, client ID {}".format(client_type, - db_client_id)) - if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: - if db_client_id not in self.dead_local_schedulers: - self.dead_local_schedulers.add(db_client_id) - elif client_type == PLASMA_MANAGER_CLIENT_TYPE: - if db_client_id not in self.dead_plasma_managers: - self.dead_plasma_managers.add(db_client_id) - # Stop tracking this plasma manager's heartbeats, since it's - # already dead. - del self.live_plasma_managers[db_client_id] - - def local_scheduler_info_handler(self, unused_channel, data): - """Handle a local scheduler heartbeat from Redis.""" - - message = (ray.gcs_utils.LocalSchedulerInfoMessage. - GetRootAsLocalSchedulerInfoMessage(data, 0)) - num_resources = message.DynamicResourcesLength() - static_resources = {} - dynamic_resources = {} - for i in range(num_resources): - dyn = message.DynamicResources(i) - static = message.StaticResources(i) - dynamic_resources[dyn.Key().decode("utf-8")] = dyn.Value() - static_resources[static.Key().decode("utf-8")] = static.Value() - - # Update the load metrics for this local scheduler. 
- client_id = binascii.hexlify(message.DbClientId()).decode("utf-8") - ip = self.local_scheduler_id_to_ip_map.get(client_id) - if ip: - self.load_metrics.update(ip, static_resources, dynamic_resources) - else: - logger.warning( - "Warning: could not find ip for client {} in {}.".format( - client_id, self.local_scheduler_id_to_ip_map)) - - def xray_heartbeat_handler(self, unused_channel, data): - """Handle an xray heartbeat message from Redis.""" + def xray_heartbeat_batch_handler(self, unused_channel, data): + """Handle an xray heartbeat batch message from Redis.""" gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( data, 0) heartbeat_data = gcs_entries.Entries(0) - message = ray.gcs_utils.HeartbeatTableData.GetRootAsHeartbeatTableData( - heartbeat_data, 0) - num_resources = message.ResourcesAvailableLabelLength() - static_resources = {} - dynamic_resources = {} - for i in range(num_resources): - dyn = message.ResourcesAvailableLabel(i) - static = message.ResourcesTotalLabel(i) - dynamic_resources[dyn] = message.ResourcesAvailableCapacity(i) - static_resources[static] = message.ResourcesTotalCapacity(i) - - # Update the load metrics for this local scheduler. - client_id = ray.utils.binary_to_hex(message.ClientId()) - ip = self.local_scheduler_id_to_ip_map.get(client_id) - if ip: - self.load_metrics.update(ip, static_resources, dynamic_resources) - else: - print("Warning: could not find ip for client {} in {}.".format( - client_id, self.local_scheduler_id_to_ip_map)) - - def plasma_manager_heartbeat_handler(self, unused_channel, data): - """Handle a plasma manager heartbeat from Redis. - - This resets the number of heartbeats that we've missed from this plasma - manager. - """ - # The first ray_constants.ID_SIZE characters are the client ID. - db_client_id = data[:ray_constants.ID_SIZE] - # Reset the number of heartbeats that we've missed from this plasma - # manager. - self.live_plasma_managers[db_client_id] = 0 - - def _entries_for_driver_in_shard(self, driver_id, redis_shard_index): - """Collect IDs of control-state entries for a driver from a shard. - - Args: - driver_id: The ID of the driver. - redis_shard_index: The index of the Redis shard to query. - - Returns: - Lists of IDs: (returned_object_ids, task_ids, put_objects). The - first two are relevant to the driver and are safe to delete. - The last contains all "put" objects in this redis shard; each - element is an (object_id, corresponding task_id) pair. - """ - # TODO(zongheng): consider adding save & restore functionalities. - redis = self.state.redis_clients[redis_shard_index] - task_table_infos = {} # task id -> TaskInfo messages - - # Scan the task table & filter to get the list of tasks belong to this - # driver. Use a cursor in order not to block the redis shards. - for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"): - entry = redis.hgetall(key) - task_info = ray.gcs_utils.TaskInfo.GetRootAsTaskInfo( - entry[b"TaskSpec"], 0) - if driver_id != task_info.DriverId(): - # Ignore tasks that aren't from this driver. - continue - task_table_infos[task_info.TaskId()] = task_info - - # Get the list of objects returned by these tasks. Note these might - # not belong to this redis shard. - returned_object_ids = [] - for task_info in task_table_infos.values(): - returned_object_ids.extend([ - task_info.Returns(i) for i in range(task_info.ReturnsLength()) - ]) - - # Also record all the ray.put()'d objects. 
- put_objects = [] - for key in redis.scan_iter(match=OBJECT_INFO_PREFIX + b"*"): - entry = redis.hgetall(key) - if entry[b"is_put"] == "0": - continue - object_id = key.split(OBJECT_INFO_PREFIX)[1] - task_id = entry[b"task"] - put_objects.append((object_id, task_id)) - - return returned_object_ids, task_table_infos.keys(), put_objects - - def _clean_up_entries_from_shard(self, object_ids, task_ids, shard_index): - redis = self.state.redis_clients[shard_index] - # Clean up (in the future, save) entries for non-empty objects. - object_ids_locs = set() - object_ids_infos = set() - for object_id in object_ids: - # OL. - obj_loc = redis.zrange(OBJECT_LOCATION_PREFIX + object_id, 0, -1) - if obj_loc: - object_ids_locs.add(object_id) - # OI. - obj_info = redis.hgetall(OBJECT_INFO_PREFIX + object_id) - if obj_info: - object_ids_infos.add(object_id) - - # Form the redis keys to delete. - keys = [TASK_TABLE_PREFIX + k for k in task_ids] - keys.extend([OBJECT_LOCATION_PREFIX + k for k in object_ids_locs]) - keys.extend([OBJECT_INFO_PREFIX + k for k in object_ids_infos]) - - if not keys: - return - # Remove with best effort. - num_deleted = redis.delete(*keys) - logger.info( - "Removed {} dead redis entries of the driver from redis shard {}.". - format(num_deleted, shard_index)) - if num_deleted != len(keys): - logger.warning( - "Failed to remove {} relevant redis entries" - " from redis shard {}.".format(len(keys) - num_deleted)) - - def _clean_up_entries_for_driver(self, driver_id): - """Remove this driver's object/task entries from all redis shards. - - Specifically, removes control-state entries of: - * all objects (OI and OL entries) created by `ray.put()` from the - driver - * all tasks belonging to the driver. - """ - # TODO(zongheng): handle function_table, client_table, log_files -- - # these are in the metadata redis server, not in the shards. - driver_object_ids = [] - driver_task_ids = [] - all_put_objects = [] - - # Collect relevant ids. - # TODO(zongheng): consider parallelizing this loop. - for shard_index in range(len(self.state.redis_clients)): - returned_object_ids, task_ids, put_objects = \ - self._entries_for_driver_in_shard(driver_id, shard_index) - driver_object_ids.extend(returned_object_ids) - driver_task_ids.extend(task_ids) - all_put_objects.extend(put_objects) - - # For the put objects, keep those from relevant tasks. - driver_task_ids_set = set(driver_task_ids) - for object_id, task_id in all_put_objects: - if task_id in driver_task_ids_set: - driver_object_ids.append(object_id) - - # Partition IDs and distribute to shards. - object_ids_per_shard = defaultdict(list) - task_ids_per_shard = defaultdict(list) - - def ToShardIndex(index): - return binary_to_object_id(index).redis_shard_hash() % len( - self.state.redis_clients) - - for object_id in driver_object_ids: - object_ids_per_shard[ToShardIndex(object_id)].append(object_id) - for task_id in driver_task_ids: - task_ids_per_shard[ToShardIndex(task_id)].append(task_id) - - # TODO(zongheng): consider parallelizing this loop. - for shard_index in range(len(self.state.redis_clients)): - self._clean_up_entries_from_shard( - object_ids_per_shard[shard_index], - task_ids_per_shard[shard_index], shard_index) - - def driver_removed_handler(self, unused_channel, data): - """Handle a notification that a driver has been removed. - This releases any GPU resources that were reserved for that driver in - Redis. 
- """ - message = ray.gcs_utils.DriverTableMessage.GetRootAsDriverTableMessage( - data, 0) - driver_id = message.DriverId() - logger.info("Driver {} has been removed.".format( - binary_to_hex(driver_id))) - - self._clean_up_entries_for_driver(driver_id) + message = (ray.gcs_utils.HeartbeatBatchTableData. + GetRootAsHeartbeatBatchTableData(heartbeat_data, 0)) + + for j in range(message.BatchLength()): + heartbeat_message = message.Batch(j) + + num_resources = heartbeat_message.ResourcesAvailableLabelLength() + static_resources = {} + dynamic_resources = {} + for i in range(num_resources): + dyn = heartbeat_message.ResourcesAvailableLabel(i) + static = heartbeat_message.ResourcesTotalLabel(i) + dynamic_resources[dyn] = ( + heartbeat_message.ResourcesAvailableCapacity(i)) + static_resources[static] = ( + heartbeat_message.ResourcesTotalCapacity(i)) + + # Update the load metrics for this local scheduler. + client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId()) + ip = self.local_scheduler_id_to_ip_map.get(client_id) + if ip: + self.load_metrics.update(ip, static_resources, + dynamic_resources) + else: + print("Warning: could not find ip for client {} in {}.".format( + client_id, self.local_scheduler_id_to_ip_map)) def _xray_clean_up_entries_for_driver(self, driver_id): """Remove this driver's object/task entries from redis. @@ -507,10 +148,8 @@ def _xray_clean_up_entries_for_driver(self, driver_id): task_table_objects = self.state.task_table() driver_id_hex = binary_to_hex(driver_id) driver_task_id_bins = set() - for task_id_hex in task_table_objects: - if len(task_table_objects[task_id_hex]) == 0: - continue - task_table_object = task_table_objects[task_id_hex][0]["TaskSpec"] + for task_id_hex, task_info in task_table_objects.items(): + task_table_object = task_info["TaskSpec"] task_driver_id_hex = task_table_object["DriverID"] if driver_id_hex != task_driver_id_hex: # Ignore tasks that aren't from this driver. @@ -520,9 +159,8 @@ def _xray_clean_up_entries_for_driver(self, driver_id): # Get objects associated with the driver. object_table_objects = self.state.object_table() driver_object_id_bins = set() - for object_id, object_table_object in object_table_objects.items(): - assert len(object_table_object) > 0 - task_id_bin = ray.local_scheduler.compute_task_id(object_id).id() + for object_id, _ in object_table_objects.items(): + task_id_bin = ray.raylet.compute_task_id(object_id).id() if task_id_bin in driver_task_id_bins: driver_object_id_bins.add(object_id.id()) @@ -580,8 +218,7 @@ def process_messages(self, max_messages=10000): max_messages: The maximum number of messages to process before returning. """ - subscribe_clients = ( - [self.primary_subscribe_client] + self.shard_subscribe_clients) + subscribe_clients = [self.primary_subscribe_client] for subscribe_client in subscribe_clients: for _ in range(max_messages): message = subscribe_client.get_message() @@ -595,22 +232,9 @@ def process_messages(self, max_messages=10000): # Determine the appropriate message handler. message_handler = None - if channel == ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL: - # The message was a heartbeat from a plasma manager. - message_handler = self.plasma_manager_heartbeat_handler - elif channel == ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL: - # The message was a heartbeat from a local scheduler - message_handler = self.local_scheduler_info_handler - elif channel == DB_CLIENT_TABLE_NAME: - # The message was a notification from the db_client table. 
- message_handler = self.db_client_notification_handler - elif channel == ray.gcs_utils.DRIVER_DEATH_CHANNEL: - # The message was a notification that a driver was removed. - logger.info("message-handler: driver_removed_handler") - message_handler = self.driver_removed_handler - elif channel == ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL: + if channel == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL: # Similar functionality as local scheduler info channel - message_handler = self.xray_heartbeat_handler + message_handler = self.xray_heartbeat_batch_handler elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL: # Handles driver death. message_handler = self.xray_driver_removed_handler @@ -622,10 +246,7 @@ def process_messages(self, max_messages=10000): message_handler(channel, data) def update_local_scheduler_map(self): - if self.use_raylet: - local_schedulers = self.state.client_table() - else: - local_schedulers = self.state.local_schedulers() + local_schedulers = self.state.client_table() self.local_scheduler_id_to_ip_map = {} for local_scheduler_info in local_schedulers: client_id = local_scheduler_info.get("DBClientID") or \ @@ -673,33 +294,11 @@ def run(self): clients and cleaning up state accordingly. """ # Initialize the subscription channel. - self.subscribe(DB_CLIENT_TABLE_NAME) - self.subscribe(ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL) - self.subscribe(ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL) - self.subscribe(ray.gcs_utils.DRIVER_DEATH_CHANNEL) - self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL, primary=False) + self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL) self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL) - # Scan the database table for dead database clients. NOTE: This must be - # called before reading any messages from the subscription channel. - # This ensures that we start in a consistent state, since we may have - # missed notifications that were sent before we connected to the - # subscription channel. - self.scan_db_client_table() - # If there were any dead clients at startup, clean up the associated - # state in the state tables. - if len(self.dead_local_schedulers) > 0: - self.cleanup_task_table() - if len(self.dead_plasma_managers) > 0: - self.cleanup_object_table() - - num_plasma_managers = len(self.live_plasma_managers) + len( - self.dead_plasma_managers) - - logger.debug("{} dead local schedulers, {} plasma managers total, {} " - "dead plasma managers".format( - len(self.dead_local_schedulers), num_plasma_managers, - len(self.dead_plasma_managers))) + # TODO(rkn): If there were any dead clients at startup, we should clean + # up the associated state in the state tables. # Handle messages from the subscription channels. while True: @@ -713,43 +312,9 @@ def run(self): self._maybe_flush_gcs() - # Record how many dead local schedulers and plasma managers we had - # at the beginning of this round. - num_dead_local_schedulers = len(self.dead_local_schedulers) - num_dead_plasma_managers = len(self.dead_plasma_managers) - # Process a round of messages. self.process_messages() - # If any new local schedulers or plasma managers were marked as - # dead in this round, clean up the associated state. - if len(self.dead_local_schedulers) > num_dead_local_schedulers: - self.cleanup_task_table() - if len(self.dead_plasma_managers) > num_dead_plasma_managers: - self.cleanup_object_table() - - # Handle plasma managers that timed out during this round. 
- plasma_manager_ids = list(self.live_plasma_managers.keys()) - for plasma_manager_id in plasma_manager_ids: - if ((self.live_plasma_managers[plasma_manager_id]) >= - ray._config.num_heartbeats_timeout()): - logger.warn("Timed out {}" - .format(PLASMA_MANAGER_CLIENT_TYPE)) - # Remove the plasma manager from the managers whose - # heartbeats we're tracking. - del self.live_plasma_managers[plasma_manager_id] - # Remove the plasma manager from the db_client table. The - # corresponding state in the object table will be cleaned - # up once we receive the notification for this db_client - # deletion. - self.redis.execute_command("RAY.DISCONNECT", - plasma_manager_id) - - # Increment the number of heartbeats that we've missed from each - # plasma manager. - for plasma_manager_id in self.live_plasma_managers: - self.live_plasma_managers[plasma_manager_id] += 1 - # Wait for a heartbeat interval before processing the next round of # messages. time.sleep(ray._config.heartbeat_timeout_milliseconds() * 1e-3) @@ -773,6 +338,12 @@ def run(self): required=False, type=str, help="the path to the autoscaling config file") + parser.add_argument( + "--redis-password", + required=False, + type=str, + default=None, + help="the password to use for Redis") parser.add_argument( "--logging-level", required=False, @@ -798,7 +369,11 @@ def run(self): else: autoscaling_config = None - monitor = Monitor(redis_ip_address, redis_port, autoscaling_config) + monitor = Monitor( + redis_ip_address, + redis_port, + autoscaling_config, + redis_password=args.redis_password) try: monitor.run() @@ -810,6 +385,5 @@ def run(self): message = "The monitor failed with the following error:\n{}".format( traceback_str) ray.utils.push_error_to_driver_through_redis( - redis_client, monitor.use_raylet, ray_constants.MONITOR_DIED_ERROR, - message) + redis_client, ray_constants.MONITOR_DIED_ERROR, message) raise e diff --git a/python/ray/plasma/__init__.py b/python/ray/plasma/__init__.py index 1ecd0c2af2dcb..6c6c18b7c555f 100644 --- a/python/ray/plasma/__init__.py +++ b/python/ray/plasma/__init__.py @@ -2,9 +2,6 @@ from __future__ import division from __future__ import print_function -from ray.plasma.plasma import (start_plasma_store, start_plasma_manager, - DEFAULT_PLASMA_STORE_MEMORY) +from ray.plasma.plasma import start_plasma_store, DEFAULT_PLASMA_STORE_MEMORY -__all__ = [ - "start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY" -] +__all__ = ["start_plasma_store", "DEFAULT_PLASMA_STORE_MEMORY"] diff --git a/python/ray/plasma/plasma.py b/python/ray/plasma/plasma.py index 60870c2b20210..53b2434260c86 100644 --- a/python/ray/plasma/plasma.py +++ b/python/ray/plasma/plasma.py @@ -3,31 +3,27 @@ from __future__ import print_function import os -import random import subprocess import sys import time -__all__ = [ - "start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY" -] +from ray.tempfile_services import get_object_store_socket_name + +__all__ = ["start_plasma_store", "DEFAULT_PLASMA_STORE_MEMORY"] PLASMA_WAIT_TIMEOUT = 2**30 DEFAULT_PLASMA_STORE_MEMORY = 10**9 -def random_name(): - return str(random.randint(0, 99999999)) - - def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, use_valgrind=False, use_profiler=False, stdout_file=None, stderr_file=None, plasma_directory=None, - huge_pages=False): + huge_pages=False, + socket_name=None): """Start a plasma store process. Args: @@ -43,6 +39,8 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, be created. 
huge_pages: a boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. + socket_name (str): If provided, it will specify the socket + name used by the plasma store. Return: A tuple of the name of the plasma store socket and the process ID of @@ -66,7 +64,7 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, plasma_store_executable = os.path.join( os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_store_server") - plasma_store_name = "/tmp/plasma_store{}".format(random_name()) + plasma_store_name = socket_name or get_object_store_socket_name() command = [ plasma_store_executable, "-s", plasma_store_name, "-m", str(plasma_store_memory) @@ -95,98 +93,3 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) time.sleep(0.1) return plasma_store_name, pid - - -def new_port(): - return random.randint(10000, 65535) - - -def start_plasma_manager(store_name, - redis_address, - node_ip_address="127.0.0.1", - plasma_manager_port=None, - num_retries=20, - use_valgrind=False, - run_profiler=False, - stdout_file=None, - stderr_file=None): - """Start a plasma manager and return the ports it listens on. - - Args: - store_name (str): The name of the plasma store socket. - redis_address (str): The address of the Redis server. - node_ip_address (str): The IP address of the node. - plasma_manager_port (int): The port to use for the plasma manager. If - this is not provided, a port will be generated at random. - use_valgrind (bool): True if the Plasma manager should be started - inside of valgrind and False otherwise. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - - Returns: - A tuple of the Plasma manager socket name, the process ID of the - Plasma manager process, and the port that the manager is - listening on. - - Raises: - Exception: An exception is raised if the manager could not be started. - """ - plasma_manager_executable = os.path.join( - os.path.abspath(os.path.dirname(__file__)), - "../core/src/plasma/plasma_manager") - plasma_manager_name = "/tmp/plasma_manager{}".format(random_name()) - if plasma_manager_port is not None: - if num_retries != 1: - raise Exception("num_retries must be 1 if port is specified.") - else: - plasma_manager_port = new_port() - process = None - counter = 0 - while counter < num_retries: - if counter > 0: - print("Plasma manager failed to start, retrying now.") - command = [ - plasma_manager_executable, - "-s", - store_name, - "-m", - plasma_manager_name, - "-h", - node_ip_address, - "-p", - str(plasma_manager_port), - "-r", - redis_address, - ] - if use_valgrind: - process = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) - elif run_profiler: - process = subprocess.Popen( - (["valgrind", "--tool=callgrind"] + command), - stdout=stdout_file, - stderr=stderr_file) - else: - process = subprocess.Popen( - command, stdout=stdout_file, stderr=stderr_file) - # This sleep is critical. If the plasma_manager fails to start because - # the port is already in use, then we need it to fail within 0.1 - # seconds. 
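
# Illustrative sketch of the updated start_plasma_store() above: callers can
# now pin the store to an explicit socket path via socket_name instead of the
# generated default. The socket path is an arbitrary example, and this assumes
# an installed Ray build that ships the plasma_store_server binary.
from ray.plasma import start_plasma_store

store_socket_name, store_process = start_plasma_store(
    plasma_store_memory=10**9,
    socket_name="/tmp/example_plasma_socket",  # placeholder path
)
print("plasma store listening on", store_socket_name)
store_process.kill()  # start_plasma_store returns a subprocess.Popen handle
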
- if use_valgrind: - time.sleep(1) - else: - time.sleep(0.1) - # See if the process has terminated - if process.poll() is None: - return plasma_manager_name, process, plasma_manager_port - # Generate a new port and try again. - plasma_manager_port = new_port() - counter += 1 - raise Exception("Couldn't start plasma manager.") diff --git a/python/ray/plasma/test/test.py b/python/ray/plasma/test/test.py deleted file mode 100644 index a67f2d255e3ac..0000000000000 --- a/python/ray/plasma/test/test.py +++ /dev/null @@ -1,559 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from numpy.testing import assert_equal -import os -import random -import signal -import subprocess -import sys -import threading -import time -import unittest - -# The ray import must come before the pyarrow import because ray modifies the -# python path so that the right version of pyarrow is found. -import ray -from ray.plasma.utils import (random_object_id, create_object_with_id, - create_object) -import ray.ray_constants as ray_constants -from ray import services -import pyarrow as pa -import pyarrow.plasma as plasma - -USE_VALGRIND = False -PLASMA_STORE_MEMORY = 1000000000 - - -def random_name(): - return str(random.randint(0, 99999999)) - - -def assert_get_object_equal(unit_test, - client1, - client2, - object_id, - memory_buffer=None, - metadata=None): - client1_buff = client1.get_buffers([object_id])[0] - client2_buff = client2.get_buffers([object_id])[0] - client1_metadata = client1.get_metadata([object_id])[0] - client2_metadata = client2.get_metadata([object_id])[0] - unit_test.assertEqual(len(client1_buff), len(client2_buff)) - unit_test.assertEqual(len(client1_metadata), len(client2_metadata)) - # Check that the buffers from the two clients are the same. - assert_equal( - np.frombuffer(client1_buff, dtype="uint8"), - np.frombuffer(client2_buff, dtype="uint8")) - # Check that the metadata buffers from the two clients are the same. - assert_equal( - np.frombuffer(client1_metadata, dtype="uint8"), - np.frombuffer(client2_metadata, dtype="uint8")) - # If a reference buffer was provided, check that it is the same as well. - if memory_buffer is not None: - assert_equal( - np.frombuffer(memory_buffer, dtype="uint8"), - np.frombuffer(client1_buff, dtype="uint8")) - # If reference metadata was provided, check that it is the same as well. - if metadata is not None: - assert_equal( - np.frombuffer(metadata, dtype="uint8"), - np.frombuffer(client1_metadata, dtype="uint8")) - - -DEFAULT_PLASMA_STORE_MEMORY = 10**9 - - -def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, - use_valgrind=False, - use_profiler=False, - stdout_file=None, - stderr_file=None): - """Start a plasma store process. - Args: - use_valgrind (bool): True if the plasma store should be started inside - of valgrind. If this is True, use_profiler must be False. - use_profiler (bool): True if the plasma store should be started inside - a profiler. If this is True, use_valgrind must be False. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - Return: - A tuple of the name of the plasma store socket and the process ID of - the plasma store process. 
- """ - if use_valgrind and use_profiler: - raise Exception("Cannot use valgrind and profiler at the same time.") - plasma_store_executable = os.path.join(pa.__path__[0], - "plasma_store_server") - plasma_store_name = "/tmp/plasma_store{}".format(random_name()) - command = [ - plasma_store_executable, "-s", plasma_store_name, "-m", - str(plasma_store_memory) - ] - if use_valgrind: - pid = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", - "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - elif use_profiler: - pid = subprocess.Popen( - ["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - else: - pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - time.sleep(0.1) - return plasma_store_name, pid - - -# Plasma client tests were moved into arrow - - -class TestPlasmaManager(unittest.TestCase): - def setUp(self): - # Start two PlasmaStores. - store_name1, self.p2 = start_plasma_store(use_valgrind=USE_VALGRIND) - store_name2, self.p3 = start_plasma_store(use_valgrind=USE_VALGRIND) - # Start a Redis server. - redis_address, _ = services.start_redis("127.0.0.1") - # Start two PlasmaManagers. - manager_name1, self.p4, self.port1 = ray.plasma.start_plasma_manager( - store_name1, redis_address, use_valgrind=USE_VALGRIND) - manager_name2, self.p5, self.port2 = ray.plasma.start_plasma_manager( - store_name2, redis_address, use_valgrind=USE_VALGRIND) - # Connect two PlasmaClients. - self.client1 = plasma.connect(store_name1, manager_name1, 64) - self.client2 = plasma.connect(store_name2, manager_name2, 64) - - # Store the processes that will be explicitly killed during tearDown so - # that a test case can remove ones that will be killed during the test. - # NOTE: If this specific order is changed, valgrind will fail. - self.processes_to_kill = [self.p4, self.p5, self.p2, self.p3] - - def tearDown(self): - # Check that the processes are still alive. - for process in self.processes_to_kill: - self.assertEqual(process.poll(), None) - - # Kill the Plasma store and Plasma manager processes. - if USE_VALGRIND: - # Give processes opportunity to finish work. - time.sleep(1) - for process in self.processes_to_kill: - process.send_signal(signal.SIGTERM) - process.wait() - if process.returncode != 0: - print("aborting due to valgrind error") - os._exit(-1) - else: - for process in self.processes_to_kill: - process.kill() - - # Clean up the Redis server. - services.cleanup() - - def test_fetch(self): - for _ in range(10): - # Create an object. - object_id1, memory_buffer1, metadata1 = create_object( - self.client1, 2000, 2000) - self.client1.fetch([object_id1]) - self.assertEqual(self.client1.contains(object_id1), True) - self.assertEqual(self.client2.contains(object_id1), False) - # Fetch the object from the other plasma manager. - # TODO(rkn): Right now we must wait for the object table to be - # updated. - while not self.client2.contains(object_id1): - self.client2.fetch([object_id1]) - # Compare the two buffers. - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id1, - memory_buffer=memory_buffer1, - metadata=metadata1) - - # Test that we can call fetch on object IDs that don't exist yet. 
- object_id2 = random_object_id() - self.client1.fetch([object_id2]) - self.assertEqual(self.client1.contains(object_id2), False) - memory_buffer2, metadata2 = create_object_with_id( - self.client2, object_id2, 2000, 2000) - # # Check that the object has been fetched. - # self.assertEqual(self.client1.contains(object_id2), True) - # Compare the two buffers. - # assert_get_object_equal(self, self.client1, self.client2, object_id2, - # memory_buffer=memory_buffer2, - # metadata=metadata2) - - # Test calling the same fetch request a bunch of times. - object_id3 = random_object_id() - self.assertEqual(self.client1.contains(object_id3), False) - self.assertEqual(self.client2.contains(object_id3), False) - for _ in range(10): - self.client1.fetch([object_id3]) - self.client2.fetch([object_id3]) - memory_buffer3, metadata3 = create_object_with_id( - self.client1, object_id3, 2000, 2000) - for _ in range(10): - self.client1.fetch([object_id3]) - self.client2.fetch([object_id3]) - # TODO(rkn): Right now we must wait for the object table to be updated. - while not self.client2.contains(object_id3): - self.client2.fetch([object_id3]) - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id3, - memory_buffer=memory_buffer3, - metadata=metadata3) - - def test_fetch_multiple(self): - for _ in range(20): - # Create two objects and a third fake one that doesn't exist. - object_id1, memory_buffer1, metadata1 = create_object( - self.client1, 2000, 2000) - missing_object_id = random_object_id() - object_id2, memory_buffer2, metadata2 = create_object( - self.client1, 2000, 2000) - object_ids = [object_id1, missing_object_id, object_id2] - # Fetch the objects from the other plasma store. The second object - # ID should timeout since it does not exist. - # TODO(rkn): Right now we must wait for the object table to be - # updated. - while ((not self.client2.contains(object_id1)) - or (not self.client2.contains(object_id2))): - self.client2.fetch(object_ids) - # Compare the buffers of the objects that do exist. - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id1, - memory_buffer=memory_buffer1, - metadata=metadata1) - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id2, - memory_buffer=memory_buffer2, - metadata=metadata2) - # Fetch in the other direction. The fake object still does not - # exist. - self.client1.fetch(object_ids) - assert_get_object_equal( - self, - self.client2, - self.client1, - object_id1, - memory_buffer=memory_buffer1, - metadata=metadata1) - assert_get_object_equal( - self, - self.client2, - self.client1, - object_id2, - memory_buffer=memory_buffer2, - metadata=metadata2) - - # Check that we can call fetch with duplicated object IDs. - object_id3 = random_object_id() - self.client1.fetch([object_id3, object_id3]) - object_id4, memory_buffer4, metadata4 = create_object( - self.client1, 2000, 2000) - time.sleep(0.1) - # TODO(rkn): Right now we must wait for the object table to be updated. - while not self.client2.contains(object_id4): - self.client2.fetch( - [object_id3, object_id3, object_id4, object_id4]) - assert_get_object_equal( - self, - self.client2, - self.client1, - object_id4, - memory_buffer=memory_buffer4, - metadata=metadata4) - - def test_wait(self): - # Test timeout. - obj_id0 = random_object_id() - self.client1.wait([obj_id0], timeout=100, num_returns=1) - # If we get here, the test worked. - - # Test wait if local objects available. 
- obj_id1 = random_object_id() - self.client1.create(obj_id1, 1000) - self.client1.seal(obj_id1) - ready, waiting = self.client1.wait( - [obj_id1], timeout=100, num_returns=1) - self.assertEqual(set(ready), {obj_id1}) - self.assertEqual(waiting, []) - - # Test wait if only one object available and only one object waited - # for. - obj_id2 = random_object_id() - self.client1.create(obj_id2, 1000) - # Don't seal. - ready, waiting = self.client1.wait( - [obj_id2, obj_id1], timeout=100, num_returns=1) - self.assertEqual(set(ready), {obj_id1}) - self.assertEqual(set(waiting), {obj_id2}) - - # Test wait if object is sealed later. - obj_id3 = random_object_id() - - def finish(): - self.client2.create(obj_id3, 1000) - self.client2.seal(obj_id3) - - t = threading.Timer(0.1, finish) - t.start() - ready, waiting = self.client1.wait( - [obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2) - self.assertEqual(set(ready), {obj_id1, obj_id3}) - self.assertEqual(set(waiting), {obj_id2}) - - # Test if the appropriate number of objects is shown if some objects - # are not ready. - ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], 100, 3) - self.assertEqual(set(ready), {obj_id1, obj_id3}) - self.assertEqual(set(waiting), {obj_id2}) - - # Don't forget to seal obj_id2. - self.client1.seal(obj_id2) - - # Test calling wait a bunch of times. - object_ids = [] - # TODO(rkn): Increasing n to 100 (or larger) will cause failures. The - # problem appears to be that the number of timers added to the manager - # event loop slow down the manager so much that some of the - # asynchronous Redis commands timeout triggering fatal failure - # callbacks. - n = 40 - for i in range(n * (n + 1) // 2): - if i % 2 == 0: - object_id, _, _ = create_object(self.client1, 200, 200) - else: - object_id, _, _ = create_object(self.client2, 200, 200) - object_ids.append(object_id) - # Try waiting for all of the object IDs on the first client. - waiting = object_ids - retrieved = [] - for i in range(1, n + 1): - ready, waiting = self.client1.wait( - waiting, timeout=1000, num_returns=i) - self.assertEqual(len(ready), i) - retrieved += ready - self.assertEqual(set(retrieved), set(object_ids)) - ready, waiting = self.client1.wait( - object_ids, timeout=1000, num_returns=len(object_ids)) - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) - # Try waiting for all of the object IDs on the second client. - waiting = object_ids - retrieved = [] - for i in range(1, n + 1): - ready, waiting = self.client2.wait( - waiting, timeout=1000, num_returns=i) - self.assertEqual(len(ready), i) - retrieved += ready - self.assertEqual(set(retrieved), set(object_ids)) - ready, waiting = self.client2.wait( - object_ids, timeout=1000, num_returns=len(object_ids)) - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) - - # Make sure that wait returns when the requested number of object IDs - # are available and does not wait for all object IDs to be available. 
- object_ids = [random_object_id() for _ in range(9)] + \ - [plasma.ObjectID(ray_constants.ID_SIZE * b'\x00')] - object_ids_perm = object_ids[:] - random.shuffle(object_ids_perm) - for i in range(10): - if i % 2 == 0: - create_object_with_id(self.client1, object_ids_perm[i], 2000, - 2000) - else: - create_object_with_id(self.client2, object_ids_perm[i], 2000, - 2000) - ready, waiting = self.client1.wait(object_ids, num_returns=(i + 1)) - self.assertEqual(set(ready), set(object_ids_perm[:(i + 1)])) - self.assertEqual(set(waiting), set(object_ids_perm[(i + 1):])) - - def test_transfer(self): - num_attempts = 100 - for _ in range(100): - # Create an object. - object_id1, memory_buffer1, metadata1 = create_object( - self.client1, 2000, 2000) - # Transfer the buffer to the the other Plasma store. There is a - # race condition on the create and transfer of the object, so keep - # trying until the object appears on the second Plasma store. - for i in range(num_attempts): - self.client1.transfer("127.0.0.1", self.port2, object_id1) - buff = self.client2.get_buffers( - [object_id1], timeout_ms=100)[0] - if buff is not None: - break - self.assertNotEqual(buff, None) - del buff - - # Compare the two buffers. - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id1, - memory_buffer=memory_buffer1, - metadata=metadata1) - # # Transfer the buffer again. - # self.client1.transfer("127.0.0.1", self.port2, object_id1) - # # Compare the two buffers. - # assert_get_object_equal(self, self.client1, self.client2, - # object_id1, - # memory_buffer=memory_buffer1, - # metadata=metadata1) - - # Create an object. - object_id2, memory_buffer2, metadata2 = create_object( - self.client2, 20000, 20000) - # Transfer the buffer to the the other Plasma store. There is a - # race condition on the create and transfer of the object, so keep - # trying until the object appears on the second Plasma store. - for i in range(num_attempts): - self.client2.transfer("127.0.0.1", self.port1, object_id2) - buff = self.client1.get_buffers( - [object_id2], timeout_ms=100)[0] - if buff is not None: - break - self.assertNotEqual(buff, None) - del buff - - # Compare the two buffers. - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id2, - memory_buffer=memory_buffer2, - metadata=metadata2) - - def test_illegal_functionality(self): - # Create an object id string. - # object_id = random_object_id() - # Create a new buffer. - # memory_buffer = self.client1.create(object_id, 20000) - # This test is commented out because it currently fails. - # # Transferring the buffer before sealing it should fail. - # self.assertRaises(Exception, - # lambda : self.manager1.transfer(1, object_id)) - pass - - def test_stresstest(self): - a = time.time() - object_ids = [] - for i in range(10000): # TODO(pcm): increase this to 100000. - object_id = random_object_id() - object_ids.append(object_id) - self.client1.create(object_id, 1) - self.client1.seal(object_id) - for object_id in object_ids: - self.client1.transfer("127.0.0.1", self.port2, object_id) - b = time.time() - a - - print("it took", b, "seconds to put and transfer the objects") - - -class TestPlasmaManagerRecovery(unittest.TestCase): - def setUp(self): - # Start a Plasma store. - self.store_name, self.p2 = start_plasma_store( - use_valgrind=USE_VALGRIND) - # Start a Redis server. - self.redis_address, _ = services.start_redis("127.0.0.1") - # Start a PlasmaManagers. 
- manager_name, self.p3, self.port1 = ray.plasma.start_plasma_manager( - self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) - # Connect a PlasmaClient. - self.client = plasma.connect(self.store_name, manager_name, 64) - - # Store the processes that will be explicitly killed during tearDown so - # that a test case can remove ones that will be killed during the test. - # NOTE: The plasma managers must be killed before the plasma store - # since plasma store death will bring down the managers. - self.processes_to_kill = [self.p3, self.p2] - - def tearDown(self): - # Check that the processes are still alive. - for process in self.processes_to_kill: - self.assertEqual(process.poll(), None) - - # Kill the Plasma store and Plasma manager processes. - if USE_VALGRIND: - # Give processes opportunity to finish work. - time.sleep(1) - for process in self.processes_to_kill: - process.send_signal(signal.SIGTERM) - process.wait() - if process.returncode != 0: - print("aborting due to valgrind error") - os._exit(-1) - else: - for process in self.processes_to_kill: - process.kill() - - # Clean up the Redis server. - services.cleanup() - - def test_delayed_start(self): - num_objects = 10 - # Create some objects using one client. - object_ids = [random_object_id() for _ in range(num_objects)] - for i in range(10): - create_object_with_id(self.client, object_ids[i], 2000, 2000) - - # Wait until the objects have been sealed in the store. - ready, waiting = self.client.wait(object_ids, num_returns=num_objects) - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) - - # Start a second plasma manager attached to the same store. - manager_name, self.p5, self.port2 = ray.plasma.start_plasma_manager( - self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) - self.processes_to_kill = [self.p5] + self.processes_to_kill - - # Check that the second manager knows about existing objects. - client2 = plasma.connect(self.store_name, manager_name, 64) - ready, waiting = [], object_ids - while True: - ready, waiting = client2.wait( - object_ids, num_returns=num_objects, timeout=0) - if len(ready) == len(object_ids): - break - - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - # Pop the argument so we don't mess with unittest's own argument - # parser. 
- if sys.argv[-1] == "valgrind": - arg = sys.argv.pop() - USE_VALGRIND = True - print("Using valgrind for tests") - unittest.main(verbosity=2) diff --git a/python/ray/plasma/utils.py b/python/ray/plasma/utils.py deleted file mode 100644 index 45feb0b1db582..0000000000000 --- a/python/ray/plasma/utils.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import random - -import pyarrow.plasma as plasma -import ray.ray_constants as ray_constants - - -def random_object_id(): - return plasma.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def generate_metadata(length): - metadata_buffer = bytearray(length) - if length > 0: - metadata_buffer[0] = random.randint(0, 255) - metadata_buffer[-1] = random.randint(0, 255) - for _ in range(100): - metadata_buffer[random.randint(0, length - 1)] = (random.randint( - 0, 255)) - return metadata_buffer - - -def write_to_data_buffer(buff, length): - array = np.frombuffer(buff, dtype="uint8") - if length > 0: - array[0] = random.randint(0, 255) - array[-1] = random.randint(0, 255) - for _ in range(100): - array[random.randint(0, length - 1)] = random.randint(0, 255) - - -def create_object_with_id(client, - object_id, - data_size, - metadata_size, - seal=True): - metadata = generate_metadata(metadata_size) - memory_buffer = client.create(object_id, data_size, metadata) - write_to_data_buffer(memory_buffer, data_size) - if seal: - client.seal(object_id) - return memory_buffer, metadata - - -def create_object(client, data_size, metadata_size, seal=True): - object_id = random_object_id() - memory_buffer, metadata = create_object_with_id( - client, object_id, data_size, metadata_size, seal=seal) - return object_id, memory_buffer, metadata diff --git a/python/ray/profiling.py b/python/ray/profiling.py index e4c2d438fc2aa..8cdd8296ed611 100644 --- a/python/ray/profiling.py +++ b/python/ray/profiling.py @@ -59,17 +59,7 @@ def profile(event_type, extra_data=None, worker=None): """ if worker is None: worker = ray.worker.global_worker - if not worker.use_raylet: - # Log the event if this is a worker and not a driver, since the - # driver's event log never gets flushed. 
- if worker.mode == ray.WORKER_MODE: - return RayLogSpanNonRaylet( - worker.profiler, event_type, contents=extra_data) - else: - return NULL_LOG_SPAN - else: - return RayLogSpanRaylet( - worker.profiler, event_type, extra_data=extra_data) + return RayLogSpanRaylet(worker.profiler, event_type, extra_data=extra_data) class Profiler(object): @@ -124,85 +114,31 @@ def flush_profile_data(self): events = self.events self.events = [] - if not self.worker.use_raylet: - event_log_key = b"event_log:" + self.worker.worker_id - event_log_value = json.dumps(events) - self.worker.local_scheduler_client.log_event( - event_log_key, event_log_value, time.time()) + if self.worker.mode == ray.WORKER_MODE: + component_type = "worker" else: - if self.worker.mode == ray.WORKER_MODE: - component_type = "worker" - else: - component_type = "driver" + component_type = "driver" - self.worker.local_scheduler_client.push_profile_events( - component_type, ray.ObjectID(self.worker.worker_id), - self.worker.node_ip_address, events) + self.worker.local_scheduler_client.push_profile_events( + component_type, ray.ObjectID(self.worker.worker_id), + self.worker.node_ip_address, events) def add_event(self, event): with self.lock: self.events.append(event) -class RayLogSpanNonRaylet(object): - """An object used to enable logging a span of events with a with statement. - - Attributes: - event_type (str): The type of the event being logged. - contents: Additional information to log. - """ - - def __init__(self, profiler, event_type, contents=None): - """Initialize a RayLogSpanNonRaylet object.""" - self.profiler = profiler - self.event_type = event_type - self.contents = contents - - def _log(self, event_type, kind, contents=None): - """Log an event to the global state store. - - This adds the event to a buffer of events locally. The buffer can be - flushed and written to the global state store by calling - flush_profile_data(). +class NoopProfiler(object): + """A no-op profile used when collect_profile_data=False.""" - Args: - event_type (str): The type of the event. - contents: More general data to store with the event. - kind (int): Either LOG_POINT, LOG_SPAN_START, or LOG_SPAN_END. This - is LOG_POINT if the event being logged happens at a single - point in time. It is LOG_SPAN_START if we are starting to log a - span of time, and it is LOG_SPAN_END if we are finishing - logging a span of time. - """ - # TODO(rkn): This code currently takes around half a microsecond. Since - # we call it tens of times per task, this adds up. We will need to redo - # the logging code, perhaps in C. - contents = {} if contents is None else contents - assert isinstance(contents, dict) - # Make sure all of the keys and values in the dictionary are strings. - contents = {str(k): str(v) for k, v in contents.items()} - self.profiler.add_event((time.time(), event_type, kind, contents)) + def start_flush_thread(self): + pass - def __enter__(self): - """Log the beginning of a span event.""" - self._log( - event_type=self.event_type, - contents=self.contents, - kind=LOG_SPAN_START) + def flush_profile_data(self): + pass - def __exit__(self, type, value, tb): - """Log the end of a span event. 
Log any exception that occurred.""" - if type is None: - self._log(event_type=self.event_type, kind=LOG_SPAN_END) - else: - self._log( - event_type=self.event_type, - contents={ - "type": str(type), - "value": value, - "traceback": traceback.format_exc() - }, - kind=LOG_SPAN_END) + def add_event(self, event): + pass class RayLogSpanRaylet(object): @@ -230,8 +166,9 @@ def set_attribute(self, key, value): value: The attribute value. """ if not isinstance(key, str) or not isinstance(value, str): - raise ValueError("The extra_data argument must be a " - "dictionary mapping strings to strings.") + raise ValueError("The arguments 'key' and 'value' must both be " + "strings. Instead they are {} and {}.".format( + key, value)) self.extra_data[key] = value def __enter__(self): @@ -250,7 +187,8 @@ def __exit__(self, type, value, tb): for key, value in self.extra_data.items(): if not isinstance(key, str) or not isinstance(value, str): raise ValueError("The extra_data argument must be a " - "dictionary mapping strings to strings.") + "dictionary mapping strings to strings. " + "Instead it is {}.".format(self.extra_data)) if type is not None: extra_data = json.dumps({ diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index a9e4519d4cf5d..a1d5e1a765438 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -5,7 +5,7 @@ import os -import ray +from ray.raylet import ObjectID def env_integer(key, default): @@ -15,7 +15,7 @@ def env_integer(key, default): ID_SIZE = 20 -NIL_JOB_ID = ray.ObjectID(ID_SIZE * b"\xff") +NIL_JOB_ID = ObjectID(ID_SIZE * b"\xff") # If a remote function or actor (or some other export) has serialized size # greater than this quantity, print an warning. @@ -41,7 +41,6 @@ def env_integer(key, default): WORKER_CRASH_PUSH_ERROR = "worker_crash" WORKER_DIED_PUSH_ERROR = "worker_died" PUT_RECONSTRUCTION_PUSH_ERROR = "put_reconstruction" -HASH_MISMATCH_PUSH_ERROR = "object_hash_mismatch" INFEASIBLE_TASK_ERROR = "infeasible_task" REMOVED_NODE_ERROR = "node_removed" MONITOR_DIED_ERROR = "monitor_died" diff --git a/python/ray/local_scheduler/__init__.py b/python/ray/raylet/__init__.py similarity index 76% rename from python/ray/local_scheduler/__init__.py rename to python/ray/raylet/__init__.py index a469776f133b6..8757f59741567 100644 --- a/python/ray/local_scheduler/__init__.py +++ b/python/ray/raylet/__init__.py @@ -2,10 +2,9 @@ from __future__ import division from __future__ import print_function -from ray.core.src.local_scheduler.liblocal_scheduler_library_python import ( +from ray.core.src.ray.raylet.liblocal_scheduler_library_python import ( Task, LocalSchedulerClient, ObjectID, check_simple_value, compute_task_id, task_from_string, task_to_string, _config, common_error) -from .local_scheduler_services import start_local_scheduler __all__ = [ "Task", "LocalSchedulerClient", "ObjectID", "check_simple_value", diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 287d3d045539f..fb2a29e45c512 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -5,6 +5,7 @@ import copy import hashlib import inspect +import logging import ray.ray_constants as ray_constants import ray.signature @@ -14,6 +15,8 @@ DEFAULT_REMOTE_FUNCTION_NUM_RETURN_VALS = 1 DEFAULT_REMOTE_FUNCTION_MAX_CALLS = 0 +logger = logging.getLogger(__name__) + def compute_function_id(function): """Compute an function ID for a function. @@ -22,7 +25,7 @@ def compute_function_id(function): func: The actual function. 
Returns: - This returns the function ID. + Raw bytes of the function id """ function_id_hash = hashlib.sha1() # Include the function module and name in the hash. @@ -39,8 +42,6 @@ def compute_function_id(function): # Compute the function ID. function_id = function_id_hash.digest() assert len(function_id) == ray_constants.ID_SIZE - function_id = ray.ObjectID(function_id) - return function_id @@ -72,7 +73,7 @@ def __init__(self, function, num_cpus, num_gpus, resources, # TODO(rkn): We store the function ID as a string, so that # RemoteFunction objects can be pickled. We should undo this when # we allow ObjectIDs to be pickled. - self._function_id = compute_function_id(self._function).id() + self._function_id = compute_function_id(function) self._function_name = ( self._function.__module__ + '.' + self._function.__name__) self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS @@ -90,11 +91,7 @@ def __init__(self, function, num_cpus, num_gpus, resources, # # Export the function. worker = ray.worker.get_global_worker() - if worker.mode == ray.worker.SCRIPT_MODE: - self._export() - elif worker.mode is None: - worker.cached_remote_functions_and_actors.append( - ("remote_function", self)) + worker.function_actor_manager.export(self) def __call__(self, *args, **kwargs): raise Exception("Remote functions cannot be called directly. Instead " @@ -103,7 +100,7 @@ def __call__(self, *args, **kwargs): def remote(self, *args, **kwargs): """This runs immediately when a remote function is called.""" - return self._submit(args=args, kwargs=kwargs) + return self._remote(args=args, kwargs=kwargs) def _submit(self, args=None, @@ -112,6 +109,23 @@ def _submit(self, num_cpus=None, num_gpus=None, resources=None): + logger.warn( + "WARNING: _submit() is being deprecated. Please use _remote().") + return self._remote( + args=args, + kwargs=kwargs, + num_return_vals=num_return_vals, + num_cpus=num_cpus, + num_gpus=num_gpus, + resources=resources) + + def _remote(self, + args=None, + kwargs=None, + num_return_vals=None, + num_cpus=None, + num_gpus=None, + resources=None): """An experimental alternate way to submit remote functions.""" worker = ray.worker.get_global_worker() worker.check_connected() @@ -141,9 +155,3 @@ def _submit(self, return object_ids[0] elif len(object_ids) > 1: return object_ids - - def _export(self): - worker = ray.worker.get_global_worker() - worker.export_remote_function( - ray.ObjectID(self._function_id), self._function_name, - self._function, self._max_calls, self) diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py index db9f52687126c..fd6ba3407eaec 100644 --- a/python/ray/rllib/__init__.py +++ b/python/ray/rllib/__init__.py @@ -2,6 +2,8 @@ from __future__ import division from __future__ import print_function +import logging + # Note: do not introduce unnecessary library dependencies here, e.g. gym. # This file is imported from the tune module in order to register RLlib agents. 
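In the remote_function.py hunk above, the old _submit entry point is kept only as a shim that logs a deprecation warning and forwards to _remote. A minimal migration sketch, assuming a hypothetical remote function f and a running Ray instance:

import ray

ray.init()

@ray.remote
def f(x):
    return x + 1

# Old spelling: still accepted, but now logs "_submit() is being deprecated".
oid = f._submit(args=[1], num_return_vals=1)
# New spelling, same keyword arguments:
oid = f._remote(args=[1], num_return_vals=1)
print(ray.get(oid))
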
from ray.tune.registry import register_trainable @@ -11,15 +13,26 @@ from ray.rllib.env.async_vector_env import AsyncVectorEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.vector_env import VectorEnv -from ray.rllib.env.serving_env import ServingEnv +from ray.rllib.env.external_env import ExternalEnv from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.evaluation.sample_batch import SampleBatch +def _setup_logger(): + logger = logging.getLogger("ray.rllib") + handler = logging.StreamHandler() + handler.setFormatter( + logging.Formatter( + "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s" + )) + logger.addHandler(handler) + logger.propagate = False + + def _register_all(): for key in [ - "PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG", "APEX_DDPG", + "PPO", "ES", "DQN", "APEX", "A3C", "PG", "DDPG", "APEX_DDPG", "IMPALA", "ARS", "A2C", "__fake", "__sigmoid_fake_data", "__parameter_tuning" ]: @@ -27,6 +40,7 @@ def _register_all(): register_trainable(key, get_agent_class(key)) +_setup_logger() _register_all() __all__ = [ @@ -37,5 +51,5 @@ def _register_all(): "AsyncVectorEnv", "MultiAgentEnv", "VectorEnv", - "ServingEnv", + "ExternalEnv", ] diff --git a/python/ray/rllib/agents/a3c/a2c.py b/python/ray/rllib/agents/a3c/a2c.py index a792d1d160831..c344592b90863 100644 --- a/python/ray/rllib/agents/a3c/a2c.py +++ b/python/ray/rllib/agents/a3c/a2c.py @@ -4,13 +4,12 @@ from ray.rllib.agents.a3c.a3c import A3CAgent, DEFAULT_CONFIG as A3C_CONFIG from ray.rllib.optimizers import SyncSamplesOptimizer +from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts -from ray.tune.trial import Resources A2C_DEFAULT_CONFIG = merge_dicts( A3C_CONFIG, { - "gpu": False, "sample_batch_size": 20, "min_iter_time_s": 10, "sample_async": False, @@ -24,16 +23,8 @@ class A2CAgent(A3CAgent): _agent_name = "A2C" _default_config = A2C_DEFAULT_CONFIG + @override(A3CAgent) def _make_optimizer(self): return SyncSamplesOptimizer(self.local_evaluator, self.remote_evaluators, self.config["optimizer"]) - - @classmethod - def default_resource_request(cls, config): - cf = merge_dicts(cls._default_config, config) - return Resources( - cpu=1, - gpu=cf["gpu_fraction"] if cf["gpu"] else 0, - extra_cpu=cf["num_workers"], - extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0) diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index afda9506248d3..43daa0b3ef781 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -7,9 +7,10 @@ from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.optimizers import AsyncGradientsOptimizer -from ray.rllib.utils import merge_dicts -from ray.tune.trial import Resources +from ray.rllib.utils.annotations import override +# yapf: disable +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # Size of rollout batch "sample_batch_size": 10, @@ -27,37 +28,14 @@ "vf_loss_coeff": 0.5, # Entropy coefficient "entropy_coeff": -0.01, - # Whether to place workers on GPUs - "use_gpu_for_workers": False, # Min time per iteration "min_iter_time_s": 5, # Workers sample async. Note that this increases the effective # sample_batch_size by up to 5x due to async buffering of batches. "sample_async": True, - # Model and preprocessor options - "model": { - # Use LSTM model. Requires TF. - "use_lstm": False, - # Max seq length for LSTM training. 
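The @override(...) decorator pulled in from ray.rllib.utils.annotations in the agent hunks above asserts that the decorated method really overrides a method of the named parent class. A rough sketch of that idea only, assuming the actual implementation may differ in its details:

def override(cls):
    """Annotation for methods that override a method of `cls`."""

    def check_override(method):
        # Fail loudly at class-definition time if `cls` has no such method.
        if method.__name__ not in dir(cls):
            raise NameError("{} does not override any method of {}".format(
                method.__name__, cls.__name__))
        return method

    return check_override
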
- "max_seq_len": 20, - # (Image statespace) - Converts image to Channels = 1 - "grayscale": True, - # (Image statespace) - Each pixel - "zero_mean": False, - # (Image statespace) - Converts image to (dim, dim, C) - "dim": 84, - # (Image statespace) - Converts image shape to (C, dim, dim) - "channel_major": False, - }, - # Configure TF for single-process operation - "tf_session_args": { - "intra_op_parallelism_threads": 1, - "inter_op_parallelism_threads": 1, - "gpu_options": { - "allow_growth": True, - }, - }, }) +# __sphinx_doc_end__ +# yapf: enable class A3CAgent(Agent): @@ -67,15 +45,7 @@ class A3CAgent(Agent): _default_config = DEFAULT_CONFIG _policy_graph = A3CPolicyGraph - @classmethod - def default_resource_request(cls, config): - cf = merge_dicts(cls._default_config, config) - return Resources( - cpu=1, - gpu=0, - extra_cpu=cf["num_workers"], - extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0) - + @override(Agent) def _init(self): if self.config["use_pytorch"]: from ray.rllib.agents.a3c.a3c_torch_policy_graph import \ @@ -87,21 +57,22 @@ def _init(self): self.local_evaluator = self.make_local_evaluator( self.env_creator, policy_cls) self.remote_evaluators = self.make_remote_evaluators( - self.env_creator, policy_cls, self.config["num_workers"], - {"num_gpus": 1 if self.config["use_gpu_for_workers"] else 0}) + self.env_creator, policy_cls, self.config["num_workers"]) self.optimizer = self._make_optimizer() - def _make_optimizer(self): - return AsyncGradientsOptimizer(self.local_evaluator, - self.remote_evaluators, - self.config["optimizer"]) - + @override(Agent) def _train(self): prev_steps = self.optimizer.num_steps_sampled start = time.time() while time.time() - start < self.config["min_iter_time_s"]: self.optimizer.step() - result = self.optimizer.collect_metrics() + result = self.optimizer.collect_metrics( + self.config["collect_metrics_timeout"]) result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps) return result + + def _make_optimizer(self): + return AsyncGradientsOptimizer(self.local_evaluator, + self.remote_evaluators, + self.config["optimizer"]) diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py index b2298acc9edb6..50258f58ac3aa 100644 --- a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py +++ b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py @@ -10,11 +10,12 @@ import ray from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.postprocessing import compute_advantages from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ LearningRateSchedule -from ray.rllib.models.misc import linear, normc_initializer from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils.annotations import override class A3CLoss(object): @@ -49,12 +50,16 @@ def __init__(self, observation_space, action_space, config): tf.float32, [None] + list(observation_space.shape)) dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) - self.model = ModelCatalog.get_model(self.observations, logit_dim, - self.config["model"]) + prev_actions = ModelCatalog.get_action_placeholder(action_space) + prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") + self.model = ModelCatalog.get_model({ + "obs": self.observations, + "prev_actions": prev_actions, + "prev_rewards": prev_rewards, + 
"is_training": self._get_is_training_placeholder(), + }, observation_space, logit_dim, self.config["model"]) action_dist = dist_class(self.model.outputs) - self.vf = tf.reshape( - linear(self.model.last_layer, 1, "value", normc_initializer(1.0)), - [-1]) + self.vf = self.model.value_function() self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) @@ -78,6 +83,8 @@ def __init__(self, observation_space, action_space, config): loss_in = [ ("obs", self.observations), ("actions", actions), + ("prev_actions", prev_actions), + ("prev_rewards", prev_rewards), ("advantages", advantages), ("value_targets", self.v_target), ] @@ -90,10 +97,12 @@ def __init__(self, observation_space, action_space, config): self.sess, obs_input=self.observations, action_sampler=action_dist.sample(), - loss=self.loss.total_loss, + loss=self.model.loss() + self.loss.total_loss, loss_inputs=loss_in, state_inputs=self.model.state_in, state_outputs=self.model.state_out, + prev_action_input=prev_actions, + prev_reward_input=prev_rewards, seq_lens=self.model.seq_lens, max_seq_len=self.config["model"]["max_seq_len"]) @@ -111,31 +120,15 @@ def __init__(self, observation_space, action_space, config): self.sess.run(tf.global_variables_initializer()) - def extra_compute_action_fetches(self): - return {"vf_preds": self.vf} - - def value(self, ob, *args): - feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} - assert len(args) == len(self.model.state_in), \ - (args, self.model.state_in) - for k, v in zip(self.model.state_in, args): - feed_dict[k] = v - vf = self.sess.run(self.vf, feed_dict) - return vf[0] - - def gradients(self, optimizer): - grads = tf.gradients(self.loss.total_loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) - clipped_grads = list(zip(self.grads, self.var_list)) - return clipped_grads - - def extra_compute_grad_fetches(self): - return self.stats_fetches - + @override(PolicyGraph) def get_initial_state(self): return self.model.state_init - def postprocess_trajectory(self, sample_batch, other_agent_batches=None): + @override(PolicyGraph) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): completed = sample_batch["dones"][-1] if completed: last_r = 0.0 @@ -143,6 +136,30 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None): next_state = [] for i in range(len(self.model.state_in)): next_state.append([sample_batch["state_out_{}".format(i)][-1]]) - last_r = self.value(sample_batch["new_obs"][-1], *next_state) + last_r = self._value(sample_batch["new_obs"][-1], *next_state) return compute_advantages(sample_batch, last_r, self.config["gamma"], self.config["lambda"]) + + @override(TFPolicyGraph) + def gradients(self, optimizer): + grads = tf.gradients(self.loss.total_loss, self.var_list) + self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) + clipped_grads = list(zip(self.grads, self.var_list)) + return clipped_grads + + @override(TFPolicyGraph) + def extra_compute_grad_fetches(self): + return self.stats_fetches + + @override(TFPolicyGraph) + def extra_compute_action_fetches(self): + return {"vf_preds": self.vf} + + def _value(self, ob, *args): + feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} + assert len(args) == len(self.model.state_in), \ + (args, self.model.state_in) + for k, v in zip(self.model.state_in, args): + feed_dict[k] = v + vf = self.sess.run(self.vf, feed_dict) + return vf[0] diff --git 
a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py index dcdada591a053..c24340d8d10a0 100644 --- a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py +++ b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py @@ -10,7 +10,9 @@ from ray.rllib.models.pytorch.misc import var_to_np from ray.rllib.models.catalog import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph +from ray.rllib.utils.annotations import override class A3CLoss(nn.Module): @@ -56,13 +58,19 @@ def __init__(self, obs_space, action_space, config): loss, loss_inputs=["obs", "actions", "advantages", "value_targets"]) + @override(TorchPolicyGraph) def extra_action_out(self, model_out): return {"vf_preds": var_to_np(model_out[1])} + @override(TorchPolicyGraph) def optimizer(self): return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"]) - def postprocess_trajectory(self, sample_batch, other_agent_batches=None): + @override(PolicyGraph) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): completed = sample_batch["dones"][-1] if completed: last_r = 0.0 diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 030ae64248d87..8e6797eede032 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -2,32 +2,92 @@ from __future__ import division from __future__ import print_function +from datetime import datetime import copy -import json +import logging import os import pickle +import six import tempfile -from datetime import datetime import tensorflow as tf import ray +from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer +from ray.rllib.utils.annotations import override from ray.rllib.utils import FilterManager, deep_update, merge_dicts -from ray.tune.registry import ENV_CREATOR, _global_registry +from ray.tune.registry import ENV_CREATOR, register_env, _global_registry from ray.tune.trainable import Trainable +from ray.tune.trial import Resources from ray.tune.logger import UnifiedLogger from ray.tune.result import DEFAULT_RESULTS_DIR +logger = logging.getLogger(__name__) + +# yapf: disable +# __sphinx_doc_begin__ COMMON_CONFIG = { + # === Debugging === + # Whether to write episode stats and videos to the agent log dir + "monitor": False, + # Set the ray.rllib.* log level for the agent process and its evaluators + "log_level": "INFO", + # Callbacks that will be run during various phases of training. These all + # take a single "info" dict as an argument. For episode callbacks, custom + # metrics can be attached to the episode by updating the episode object's + # custom metrics dict (see examples/custom_metrics_and_callbacks.py). + "callbacks": { + "on_episode_start": None, # arg: {"env": .., "episode": ...} + "on_episode_step": None, # arg: {"env": .., "episode": ...} + "on_episode_end": None, # arg: {"env": .., "episode": ...} + "on_sample_end": None, # arg: {"samples": .., "evaluator": ...} + "on_train_result": None, # arg: {"agent": ..., "result": ...} + }, + + # === Policy === + # Arguments to pass to model. See models/catalog.py for a full list of the + # available model options. + "model": MODEL_DEFAULTS, + # Arguments to pass to the policy optimizer. These vary by optimizer. 
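Each hook in the new "callbacks" block above receives a single info dict. A small hypothetical example of attaching a custom metric at episode end, assuming the episode object exposes its custom metrics as episode.custom_metrics (the metric name below is made up for illustration):

def on_episode_end(info):
    episode = info["episode"]
    # Record a custom per-episode metric, as described in the config comment.
    episode.custom_metrics["my_metric"] = 1.0

config = {
    "callbacks": {"on_episode_end": on_episode_end},
}
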
+ "optimizer": {}, + + # === Environment === # Discount factor of the MDP "gamma": 0.99, - # Number of steps after which the rollout gets cut + # Number of steps after which the episode is forced to terminate "horizon": None, - # Number of environments to evaluate vectorwise per worker. - "num_envs_per_worker": 1, + # Arguments to pass to the env creator + "env_config": {}, + # Environment name can also be passed via config + "env": None, + # Whether to clip rewards prior to experience postprocessing. Setting to + # None means clip for Atari only. + "clip_rewards": None, + # Whether to np.clip() actions to the action space low/high range spec. + "clip_actions": True, + # Whether to use rllib or deepmind preprocessors by default + "preprocessor_pref": "deepmind", + + # === Resources === # Number of actors used for parallelism "num_workers": 2, + # Number of GPUs to allocate to the driver. Note that not all algorithms + # can take advantage of driver GPUs. This can be fraction (e.g., 0.3 GPUs). + "num_gpus": 0, + # Number of CPUs to allocate per worker. + "num_cpus_per_worker": 1, + # Number of GPUs to allocate per worker. This can be fractional. + "num_gpus_per_worker": 0, + # Any custom resources to allocate per worker. + "custom_resources_per_worker": {}, + # Number of CPUs to allocate for the driver. Note: this only takes effect + # when running in Tune. + "num_cpus_for_driver": 1, + + # === Execution === + # Number of environments to evaluate vectorwise per worker. + "num_envs_per_worker": 1, # Default sample batch size "sample_batch_size": 200, # Training batch size, if applicable. Should be >= sample_batch_size. @@ -37,30 +97,15 @@ "batch_mode": "truncate_episodes", # Whether to use a background thread for sampling (slightly off-policy) "sample_async": False, - # Which observation filter to apply to the observation + # Element-wise observation filter, either "NoFilter" or "MeanStdFilter" "observation_filter": "NoFilter", # Whether to synchronize the statistics of remote filters. "synchronize_filters": True, - # Whether to clip rewards prior to experience postprocessing. Setting to - # None means clip for Atari only. - "clip_rewards": None, - # Whether to use rllib or deepmind preprocessors - "preprocessor_pref": "deepmind", - # Arguments to pass to the env creator - "env_config": {}, - # Environment name can also be passed via config - "env": None, - # Arguments to pass to model - "model": { - "use_lstm": False, - "max_seq_len": 20, - }, - # Arguments to pass to the rllib optimizer - "optimizer": {}, # Configure TF for single-process operation by default "tf_session_args": { - "intra_op_parallelism_threads": 1, - "inter_op_parallelism_threads": 1, + # note: overriden by `local_evaluator_tf_session_args` + "intra_op_parallelism_threads": 2, + "inter_op_parallelism_threads": 2, "gpu_options": { "allow_growth": True, }, @@ -70,12 +115,17 @@ }, "allow_soft_placement": True, # required by PPO multi-gpu }, + # Override the following tf session args on the local evaluator + "local_evaluator_tf_session_args": { + # Allow a higher level of parallelism by default, but not unlimited + # since that can cause crashes with many concurrent drivers. 
+ "intra_op_parallelism_threads": 8, + "inter_op_parallelism_threads": 8, + }, # Whether to LZ4 compress observations "compress_observations": False, - # Whether to write episode stats and videos to the agent log dir - "monitor": False, - # Allocate a fraction of a GPU instead of one (e.g., 0.3 GPUs) - "gpu_fraction": 1, + # Drop metric batches from unresponsive workers after this many seconds + "collect_metrics_timeout": 180, # === Multiagent === "multiagent": { @@ -88,6 +138,8 @@ "policies_to_train": None, }, } +# __sphinx_doc_end__ +# yapf: enable def with_common_config(extra_config): @@ -115,68 +167,6 @@ class Agent(Trainable): "tf_session_args", "env_config", "model", "optimizer", "multiagent" ] - def make_local_evaluator(self, env_creator, policy_graph): - """Convenience method to return configured local evaluator.""" - - return self._make_evaluator( - PolicyEvaluator, - env_creator, - policy_graph, - 0, - # important: allow local tf to use multiple CPUs for optimization - merge_dicts( - self.config, { - "tf_session_args": { - "intra_op_parallelism_threads": None, - "inter_op_parallelism_threads": None, - } - })) - - def make_remote_evaluators(self, env_creator, policy_graph, count, - remote_args): - """Convenience method to return a number of remote evaluators.""" - - cls = PolicyEvaluator.as_remote(**remote_args).remote - return [ - self._make_evaluator(cls, env_creator, policy_graph, i + 1, - self.config) for i in range(count) - ] - - def _make_evaluator(self, cls, env_creator, policy_graph, worker_index, - config): - def session_creator(): - return tf.Session( - config=tf.ConfigProto(**config["tf_session_args"])) - - return cls( - env_creator, - self.config["multiagent"]["policy_graphs"] or policy_graph, - policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], - policies_to_train=self.config["multiagent"]["policies_to_train"], - tf_session_creator=(session_creator - if config["tf_session_args"] else None), - batch_steps=config["sample_batch_size"], - batch_mode=config["batch_mode"], - episode_horizon=config["horizon"], - preprocessor_pref=config["preprocessor_pref"], - sample_async=config["sample_async"], - compress_observations=config["compress_observations"], - num_envs=config["num_envs_per_worker"], - observation_filter=config["observation_filter"], - clip_rewards=config["clip_rewards"], - env_config=config["env_config"], - model_config=config["model"], - policy_config=config, - worker_index=worker_index, - monitor_path=self.logdir if config["monitor"] else None) - - @classmethod - def resource_help(cls, config): - return ("\n\nYou can adjust the resource requests of RLlib agents by " - "setting `num_workers` and other configs. See the " - "DEFAULT_CONFIG defined by each agent for more info.\n\n" - "The config of this agent is: " + json.dumps(config)) - def __init__(self, config=None, env=None, logger_creator=None): """Initialize an RLLib agent. @@ -189,17 +179,19 @@ def __init__(self, config=None, env=None, logger_creator=None): """ config = config or {} + Agent._validate_config(config) # Vars to synchronize to evaluators on each train call self.global_vars = {"timestep": 0} # Agents allow env ids to be passed directly to the constructor. 
- self._env_id = env or config.get("env") + self._env_id = _register_if_needed(env or config.get("env")) # Create a default logger creator if no logger_creator is specified if logger_creator is None: timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") - logdir_prefix = '_'.join([self._agent_name, self._env_id, timestr]) + logdir_prefix = "{}_{}_{}".format(self._agent_name, self._env_id, + timestr) def default_logger_creator(config): """Creates a Unified logger with a default logdir prefix @@ -215,6 +207,19 @@ def default_logger_creator(config): Trainable.__init__(self, config, logger_creator) + @classmethod + @override(Trainable) + def default_resource_request(cls, config): + cf = dict(cls._default_config, **config) + Agent._validate_config(cf) + # TODO(ekl): add custom resources here once tune supports them + return Resources( + cpu=cf["num_cpus_for_driver"], + gpu=cf["num_gpus"], + extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], + extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) + + @override(Trainable) def train(self): """Overrides super.train to synchronize global vars.""" @@ -224,6 +229,7 @@ def train(self): self.optimizer.local_evaluator.set_global_vars(self.global_vars) for ev in self.optimizer.remote_evaluators: ev.set_global_vars.remote(self.global_vars) + logger.debug("updated global vars: {}".format(self.global_vars)) if (self.config.get("observation_filter", "NoFilter") != "NoFilter" and hasattr(self, "local_evaluator")): @@ -231,13 +237,22 @@ def train(self): self.local_evaluator.filters, self.remote_evaluators, update_remote=self.config["synchronize_filters"]) - - return Trainable.train(self) - - def _setup(self): + logger.debug("synchronized filters: {}".format( + self.local_evaluator.filters)) + + result = Trainable.train(self) + if self.config["callbacks"].get("on_train_result"): + self.config["callbacks"]["on_train_result"]({ + "agent": self, + "result": result, + }) + return result + + @override(Trainable) + def _setup(self, config): env = self._env_id if env: - self.config["env"] = env + config["env"] = env if _global_registry.contains(ENV_CREATOR, env): self.env_creator = _global_registry.get(ENV_CREATOR, env) else: @@ -247,36 +262,41 @@ def _setup(self): self.env_creator = lambda env_config: None # Merge the supplied config with the class default - merged_config = self._default_config.copy() - merged_config = deep_update(merged_config, self.config, + merged_config = copy.deepcopy(self._default_config) + merged_config = deep_update(merged_config, config, self._allow_unknown_configs, self._allow_unknown_subkeys) self.config = merged_config + if self.config.get("log_level"): + logging.getLogger("ray.rllib").setLevel(self.config["log_level"]) # TODO(ekl) setting the graph is unnecessary for PyTorch agents with tf.Graph().as_default(): self._init() - def _init(self): - """Subclasses should override this for custom initialization.""" - - raise NotImplementedError - - @property - def iteration(self): - """Current training iter, auto-incremented with each train() call.""" - - return self._iteration + @override(Trainable) + def _stop(self): + # workaround for https://github.com/ray-project/ray/issues/1516 + if hasattr(self, "remote_evaluators"): + for ev in self.remote_evaluators: + ev.__ray_terminate__.remote() + if hasattr(self, "optimizer"): + self.optimizer.stop() - @property - def _agent_name(self): - """Subclasses should override this to declare their name.""" + @override(Trainable) + def _save(self, checkpoint_dir): + checkpoint_path = 
os.path.join(checkpoint_dir, + "checkpoint-{}".format(self.iteration)) + pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) + return checkpoint_path - raise NotImplementedError + @override(Trainable) + def _restore(self, checkpoint_path): + extra_data = pickle.load(open(checkpoint_path, "rb")) + self.__setstate__(extra_data) - @property - def _default_config(self): - """Subclasses should override this to declare their default config.""" + def _init(self): + """Subclasses should override this for custom initialization.""" raise NotImplementedError @@ -299,13 +319,29 @@ def compute_action(self, observation, state=None, policy_id="default"): observation, update=False) if state: return self.local_evaluator.for_policy( - lambda p: p.compute_single_action( - filtered_obs, state, is_training=False), + lambda p: p.compute_single_action(filtered_obs, state), policy_id=policy_id) return self.local_evaluator.for_policy( - lambda p: p.compute_single_action( - filtered_obs, state, is_training=False)[0], - policy_id=policy_id) + lambda p: p.compute_single_action(filtered_obs, state)[0], + policy_id=policy_id) + + @property + def iteration(self): + """Current training iter, auto-incremented with each train() call.""" + + return self._iteration + + @property + def _agent_name(self): + """Subclasses should override this to declare their name.""" + + raise NotImplementedError + + @property + def _default_config(self): + """Subclasses should override this to declare their default config.""" + + raise NotImplementedError def get_weights(self, policies=None): """Return a dictionary of policy ids to weights. @@ -324,11 +360,89 @@ def set_weights(self, weights): """ self.local_evaluator.set_weights(weights) - def _stop(self): - # workaround for https://github.com/ray-project/ray/issues/1516 - if hasattr(self, "remote_evaluators"): - for ev in self.remote_evaluators: - ev.__ray_terminate__.remote() + def make_local_evaluator(self, env_creator, policy_graph): + """Convenience method to return configured local evaluator.""" + + return self._make_evaluator( + PolicyEvaluator, + env_creator, + policy_graph, + 0, + # important: allow local tf to use more CPUs for optimization + merge_dicts(self.config, { + "tf_session_args": self. 
+ config["local_evaluator_tf_session_args"] + })) + + def make_remote_evaluators(self, env_creator, policy_graph, count): + """Convenience method to return a number of remote evaluators.""" + + remote_args = { + "num_cpus": self.config["num_cpus_per_worker"], + "num_gpus": self.config["num_gpus_per_worker"], + "resources": self.config["custom_resources_per_worker"], + } + + cls = PolicyEvaluator.as_remote(**remote_args).remote + return [ + self._make_evaluator(cls, env_creator, policy_graph, i + 1, + self.config) for i in range(count) + ] + + def _make_evaluator(self, cls, env_creator, policy_graph, worker_index, + config): + def session_creator(): + logger.debug("Creating TF session {}".format( + config["tf_session_args"])) + return tf.Session( + config=tf.ConfigProto(**config["tf_session_args"])) + + return cls( + env_creator, + self.config["multiagent"]["policy_graphs"] or policy_graph, + policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], + policies_to_train=self.config["multiagent"]["policies_to_train"], + tf_session_creator=(session_creator + if config["tf_session_args"] else None), + batch_steps=config["sample_batch_size"], + batch_mode=config["batch_mode"], + episode_horizon=config["horizon"], + preprocessor_pref=config["preprocessor_pref"], + sample_async=config["sample_async"], + compress_observations=config["compress_observations"], + num_envs=config["num_envs_per_worker"], + observation_filter=config["observation_filter"], + clip_rewards=config["clip_rewards"], + clip_actions=config["clip_actions"], + env_config=config["env_config"], + model_config=config["model"], + policy_config=config, + worker_index=worker_index, + monitor_path=self.logdir if config["monitor"] else None, + log_level=config["log_level"], + callbacks=config["callbacks"]) + + @classmethod + def resource_help(cls, config): + return ("\n\nYou can adjust the resource requests of RLlib agents by " + "setting `num_workers` and other configs. 
See the " + "DEFAULT_CONFIG defined by each agent for more info.\n\n" + "The config of this agent is: {}".format(config)) + + @staticmethod + def _validate_config(config): + if "gpu" in config: + raise ValueError( + "The `gpu` config is deprecated, please use `num_gpus=0|1` " + "instead.") + if "gpu_fraction" in config: + raise ValueError( + "The `gpu_fraction` config is deprecated, please use " + "`num_gpus=` instead.") + if "use_gpu_for_workers" in config: + raise ValueError( + "The `use_gpu_for_workers` config is deprecated, please use " + "`num_gpus_per_worker=1` instead.") def __getstate__(self): state = {} @@ -347,16 +461,14 @@ def __setstate__(self, state): if "optimizer" in state: self.optimizer.restore(state["optimizer"]) - def _save(self, checkpoint_dir): - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - pickle.dump(self.__getstate__(), - open(checkpoint_path + ".agent_state", "wb")) - return checkpoint_path - def _restore(self, checkpoint_path): - extra_data = pickle.load(open(checkpoint_path + ".agent_state", "rb")) - self.__setstate__(extra_data) +def _register_if_needed(env_object): + if isinstance(env_object, six.string_types): + return env_object + elif isinstance(env_object, type): + name = env_object.__name__ + register_env(name, lambda config: env_object(config)) + return name def get_agent_class(alg): @@ -389,9 +501,6 @@ def get_agent_class(alg): elif alg == "A2C": from ray.rllib.agents import a3c return a3c.A2CAgent - elif alg == "BC": - from ray.rllib.agents import bc - return bc.BCAgent elif alg == "PG": from ray.rllib.agents import pg return pg.PGAgent diff --git a/python/ray/rllib/agents/ars/ars.py b/python/ray/rllib/agents/ars/ars.py index e1a9459857716..1b39a79d0c13b 100644 --- a/python/ray/rllib/agents/ars/ars.py +++ b/python/ray/rllib/agents/ars/ars.py @@ -7,38 +7,42 @@ from __future__ import print_function from collections import namedtuple +import logging import numpy as np import time import ray from ray.rllib.agents import Agent, with_common_config -from ray.tune.trial import Resources from ray.rllib.agents.ars import optimizers from ray.rllib.agents.ars import policies -from ray.rllib.agents.es import tabular_logger as tlogger from ray.rllib.agents.ars import utils +from ray.rllib.utils.annotations import override +from ray.rllib.utils import FilterManager + +logger = logging.getLogger(__name__) Result = namedtuple("Result", [ "noise_indices", "noisy_returns", "sign_noisy_returns", "noisy_lengths", "eval_returns", "eval_lengths" ]) +# yapf: disable +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ - 'noise_stdev': 0.02, # std deviation of parameter noise - 'num_rollouts': 32, # number of perturbs to try - 'rollouts_used': 32, # number of perturbs to keep in gradient estimate - 'num_workers': 2, - 'sgd_stepsize': 0.01, # sgd step-size - 'observation_filter': "MeanStdFilter", - 'noise_size': 250000000, - 'eval_prob': 0.03, # probability of evaluating the parameter rewards - 'report_length': 10, # how many of the last rewards we average over - 'env_config': {}, - 'offset': 0, - 'policy_type': "LinearPolicy", # ["LinearPolicy", "MLPPolicy"] - "fcnet_hiddens": [32, 32], # fcnet structure of MLPPolicy + "noise_stdev": 0.02, # std deviation of parameter noise + "num_rollouts": 32, # number of perturbs to try + "rollouts_used": 32, # number of perturbs to keep in gradient estimate + "num_workers": 2, + "sgd_stepsize": 0.01, # sgd step-size + "observation_filter": "MeanStdFilter", + "noise_size": 250000000, + 
"eval_prob": 0.03, # probability of evaluating the parameter rewards + "report_length": 10, # how many of the last rewards we average over + "offset": 0, }) +# __sphinx_doc_end__ +# yapf: enable @ray.remote @@ -67,15 +71,9 @@ def get_delta(self, dim): @ray.remote class Worker(object): - def __init__(self, - config, - policy_params, - env_creator, - noise, - min_task_runtime=0.2): + def __init__(self, config, env_creator, noise, min_task_runtime=0.2): self.min_task_runtime = min_task_runtime self.config = config - self.policy_params = policy_params self.noise = SharedNoiseTable(noise) self.env = env_creator(config["env_config"]) @@ -83,15 +81,25 @@ def __init__(self, self.preprocessor = models.ModelCatalog.get_preprocessor(self.env) self.sess = utils.make_session(single_threaded=True) - if config["policy_type"] == "LinearPolicy": - self.policy = policies.LinearPolicy( - self.sess, self.env.action_space, self.preprocessor, - config["observation_filter"], **policy_params) - else: - self.policy = policies.MLPPolicy( - self.sess, self.env.action_space, self.preprocessor, - config["observation_filter"], config["fcnet_hiddens"], - **policy_params) + self.policy = policies.GenericPolicy( + self.sess, self.env.action_space, self.env.observation_space, + self.preprocessor, config["observation_filter"], config["model"]) + + @property + def filters(self): + return {"default": self.policy.get_filter()} + + def sync_filters(self, new_filters): + for k in self.filters: + self.filters[k].sync(new_filters[k]) + + def get_filters(self, flush_after=False): + return_filters = {} + for k, f in self.filters.items(): + return_filters[k] = f.as_serializable() + if flush_after: + f.clear_buffer() + return return_filters def rollout(self, timestep_limit, add_noise=False): rollout_rewards, rollout_length = policies.rollout( @@ -154,31 +162,16 @@ class ARSAgent(Agent): _agent_name = "ARS" _default_config = DEFAULT_CONFIG - @classmethod - def default_resource_request(cls, config): - cf = dict(cls._default_config, **config) - return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"]) - + @override(Agent) def _init(self): - policy_params = {"action_noise_std": 0.0} - - # register the linear network - utils.register_linear_network() - env = self.env_creator(self.config["env_config"]) from ray.rllib import models preprocessor = models.ModelCatalog.get_preprocessor(env) self.sess = utils.make_session(single_threaded=False) - if self.config["policy_type"] == "LinearPolicy": - self.policy = policies.LinearPolicy( - self.sess, env.action_space, preprocessor, - self.config["observation_filter"], **policy_params) - else: - self.policy = policies.MLPPolicy( - self.sess, env.action_space, preprocessor, - self.config["observation_filter"], - self.config["fcnet_hiddens"], **policy_params) + self.policy = policies.GenericPolicy( + self.sess, env.action_space, env.observation_space, preprocessor, + self.config["observation_filter"], self.config["model"]) self.optimizer = optimizers.SGD(self.policy, self.config["sgd_stepsize"]) @@ -187,41 +180,22 @@ def _init(self): self.report_length = self.config["report_length"] # Create the shared noise table. - print("Creating shared noise table.") + logger.info("Creating shared noise table.") noise_id = create_shared_noise.remote(self.config["noise_size"]) self.noise = SharedNoiseTable(ray.get(noise_id)) # Create the actors. 
- print("Creating actors.") + logger.info("Creating actors.") self.workers = [ - Worker.remote(self.config, policy_params, self.env_creator, - noise_id) for _ in range(self.config["num_workers"]) + Worker.remote(self.config, self.env_creator, noise_id) + for _ in range(self.config["num_workers"]) ] self.episodes_so_far = 0 self.reward_list = [] self.tstart = time.time() - def _collect_results(self, theta_id, min_episodes): - num_episodes, num_timesteps = 0, 0 - results = [] - while num_episodes < min_episodes: - print("Collected {} episodes {} timesteps so far this iter".format( - num_episodes, num_timesteps)) - rollout_ids = [ - worker.do_rollouts.remote(theta_id) for worker in self.workers - ] - # Get the results of the rollouts. - for result in ray.get(rollout_ids): - results.append(result) - # Update the number of episodes and the number of timesteps - # keeping in mind that result.noisy_lengths is a list of lists, - # where the inner lists have length 2. - num_episodes += sum(len(pair) for pair in result.noisy_lengths) - num_timesteps += sum( - sum(pair) for pair in result.noisy_lengths) - return results, num_episodes, num_timesteps - + @override(Agent) def _train(self): config = self.config @@ -287,7 +261,6 @@ def _train(self): g /= np.std(noisy_returns) assert (g.shape == (self.policy.num_params, ) and g.dtype == np.float32) - print('the number of policy params is, ', self.policy.num_params) # Compute the new weights theta. theta, update_ratio = self.optimizer.update(-g) # Set the new weights in the local copy of the policy. @@ -296,18 +269,14 @@ def _train(self): if len(all_eval_returns) > 0: self.reward_list.append(eval_returns.mean()) - tlogger.record_tabular("NoisyEpRewMean", noisy_returns.mean()) - tlogger.record_tabular("NoisyEpRewStd", noisy_returns.std()) - tlogger.record_tabular("NoisyEpLenMean", noisy_lengths.mean()) - - tlogger.record_tabular("WeightsNorm", float(np.square(theta).sum())) - tlogger.record_tabular("WeightsStd", float(np.std(theta))) - tlogger.record_tabular("Grad2Norm", float(np.sqrt(np.square(g).sum()))) - tlogger.record_tabular("UpdateRatio", float(update_ratio)) - tlogger.dump_tabular() + # Now sync the filters + FilterManager.synchronize({ + "default": self.policy.get_filter() + }, self.workers) info = { "weights_norm": np.square(theta).sum(), + "weights_std": np.std(theta), "grad_norm": np.square(g).sum(), "update_ratio": update_ratio, "episodes_this_iter": noisy_lengths.size, @@ -322,20 +291,49 @@ def _train(self): return result + @override(Agent) def _stop(self): # workaround for https://github.com/ray-project/ray/issues/1516 for w in self.workers: w.__ray_terminate__.remote() + @override(Agent) + def compute_action(self, observation): + return self.policy.compute(observation, update=True)[0] + + def _collect_results(self, theta_id, min_episodes): + num_episodes, num_timesteps = 0, 0 + results = [] + while num_episodes < min_episodes: + logger.info( + "Collected {} episodes {} timesteps so far this iter".format( + num_episodes, num_timesteps)) + rollout_ids = [ + worker.do_rollouts.remote(theta_id) for worker in self.workers + ] + # Get the results of the rollouts. + for result in ray.get(rollout_ids): + results.append(result) + # Update the number of episodes and the number of timesteps + # keeping in mind that result.noisy_lengths is a list of lists, + # where the inner lists have length 2. 
+ num_episodes += sum(len(pair) for pair in result.noisy_lengths) + num_timesteps += sum( + sum(pair) for pair in result.noisy_lengths) + + return results, num_episodes, num_timesteps + def __getstate__(self): return { "weights": self.policy.get_weights(), + "filter": self.policy.get_filter(), "episodes_so_far": self.episodes_so_far, } def __setstate__(self, state): - self.policy.set_weights(state["weights"]) self.episodes_so_far = state["episodes_so_far"] - - def compute_action(self, observation): - return self.policy.compute(observation, update=True)[0] + self.policy.set_weights(state["weights"]) + self.policy.set_filter(state["filter"]) + FilterManager.synchronize({ + "default": self.policy.get_filter() + }, self.workers) diff --git a/python/ray/rllib/agents/ars/policies.py b/python/ray/rllib/agents/ars/policies.py index 3a25d68eb6b3e..27f664655f423 100644 --- a/python/ray/rllib/agents/ars/policies.py +++ b/python/ray/rllib/agents/ars/policies.py @@ -10,8 +10,8 @@ import tensorflow as tf import ray +from ray.rllib.evaluation.sampler import _unbatch_tuple_actions from ray.rllib.utils.filter import get_filter -from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.models import ModelCatalog @@ -57,16 +57,11 @@ class GenericPolicy(object): def __init__(self, sess, action_space, + obs_space, preprocessor, observation_filter, - action_noise_std, - options={}): - - if len(preprocessor.shape) > 1: - raise UnsupportedSpaceException( - "Observation space {} is not supported with ARS.".format( - preprocessor.shape)) - + model_config, + action_noise_std=0.0): self.sess = sess self.action_space = action_space self.action_noise_std = action_noise_std @@ -78,9 +73,11 @@ def __init__(self, # Policy network. dist_class, dist_dim = ModelCatalog.get_action_dist( - action_space, dist_type="deterministic") + action_space, model_config, dist_type="deterministic") - model = ModelCatalog.get_model(self.inputs, dist_dim, options=options) + model = ModelCatalog.get_model({ + "obs": self.inputs + }, obs_space, dist_dim, model_config) dist = dist_class(model.outputs) self.sampler = dist.sample() @@ -97,6 +94,7 @@ def compute(self, observation, add_noise=False, update=True): observation = self.observation_filter(observation[None], update=update) action = self.sess.run( self.sampler, feed_dict={self.inputs: observation}) + action = _unbatch_tuple_actions(action) if add_noise and isinstance(self.action_space, gym.spaces.Box): action += np.random.randn(*action.shape) * self.action_noise_std return action @@ -104,33 +102,11 @@ def compute(self, observation, add_noise=False, update=True): def set_weights(self, x): self.variables.set_flat(x) - def get_weights(self): - return self.variables.get_flat() + def set_filter(self, obs_filter): + self.observation_filter = obs_filter + def get_filter(self): + return self.observation_filter -class LinearPolicy(GenericPolicy): - def __init__(self, sess, action_space, preprocessor, observation_filter, - action_noise_std): - options = {"custom_model": "LinearNetwork"} - GenericPolicy.__init__( - self, - sess, - action_space, - preprocessor, - observation_filter, - action_noise_std, - options=options) - - -class MLPPolicy(GenericPolicy): - def __init__(self, sess, action_space, preprocessor, observation_filter, - fcnet_hiddens, action_noise_std): - options = {"fcnet_hiddens": fcnet_hiddens} - GenericPolicy.__init__( - self, - sess, - action_space, - preprocessor, - observation_filter, - action_noise_std, - options=options) + def get_weights(self): + return 
self.variables.get_flat() diff --git a/python/ray/rllib/agents/ars/utils.py b/python/ray/rllib/agents/ars/utils.py index a70dd97bb61a3..1575e46c38370 100644 --- a/python/ray/rllib/agents/ars/utils.py +++ b/python/ray/rllib/agents/ars/utils.py @@ -7,9 +7,6 @@ import numpy as np import tensorflow as tf -from ray.rllib.models import ModelCatalog, Model -import tensorflow.contrib.slim as slim -from ray.rllib.models.misc import normc_initializer def compute_ranks(x): @@ -62,21 +59,3 @@ def batched_weighted_sum(weights, vecs, batch_size): np.asarray(batch_vecs, dtype=np.float32)) num_items_summed += len(batch_weights) return total, num_items_summed - - -class LinearNetwork(Model): - """Generic linear network.""" - - def _build_layers(self, inputs, num_outputs, _): - with tf.name_scope("linear"): - output = slim.fully_connected( - inputs, - num_outputs, - weights_initializer=normc_initializer(0.01), - activation_fn=None, - ) - return output, inputs - - -def register_linear_network(): - ModelCatalog.register_custom_model("LinearNetwork", LinearNetwork) diff --git a/python/ray/rllib/agents/bc/__init__.py b/python/ray/rllib/agents/bc/__init__.py deleted file mode 100644 index eb0f8dc2d7dd3..0000000000000 --- a/python/ray/rllib/agents/bc/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from ray.rllib.agents.bc.bc import BCAgent, DEFAULT_CONFIG - -__all__ = ["BCAgent", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/bc/bc.py b/python/ray/rllib/agents/bc/bc.py deleted file mode 100644 index b2552bf990f56..0000000000000 --- a/python/ray/rllib/agents/bc/bc.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.rllib.agents.agent import Agent -from ray.rllib.agents.bc.bc_evaluator import BCEvaluator, \ - GPURemoteBCEvaluator, RemoteBCEvaluator -from ray.rllib.optimizers import AsyncGradientsOptimizer -from ray.rllib.utils import merge_dicts -from ray.tune.trial import Resources - -DEFAULT_CONFIG = { - # Number of workers (excluding master) - "num_workers": 1, - # Size of rollout batch - "batch_size": 100, - # Max global norm for each gradient calculated by worker - "grad_clip": 40.0, - # Learning rate - "lr": 0.0001, - # Whether to use a GPU for local optimization. 
- "gpu": False, - # Whether to place workers on GPUs - "use_gpu_for_workers": False, - # Model and preprocessor options - "model": { - # (Image statespace) - Converts image to Channels = 1 - "grayscale": True, - # (Image statespace) - Each pixel - "zero_mean": False, - # (Image statespace) - Converts image to (dim, dim, C) - "dim": 84, - # (Image statespace) - Converts image shape to (C, dim, dim) - "channel_major": False - }, - # Arguments to pass to the rllib optimizer - "optimizer": { - # Number of gradients applied for each `train` step - "grads_per_step": 100, - }, - # Arguments to pass to the env creator - "env_config": {}, -} - - -class BCAgent(Agent): - _agent_name = "BC" - _default_config = DEFAULT_CONFIG - _allow_unknown_configs = True - - @classmethod - def default_resource_request(cls, config): - cf = merge_dicts(cls._default_config, config) - if cf["use_gpu_for_workers"]: - num_gpus_per_worker = cf["gpu_fraction"] - else: - num_gpus_per_worker = 0 - return Resources( - cpu=1, - gpu=cf["gpu"] and cf["gpu_fraction"] or 0, - extra_cpu=cf["num_workers"], - extra_gpu=num_gpus_per_worker * cf["num_workers"]) - - def _init(self): - self.local_evaluator = BCEvaluator(self.env_creator, self.config, - self.logdir) - if self.config["use_gpu_for_workers"]: - remote_cls = GPURemoteBCEvaluator - else: - remote_cls = RemoteBCEvaluator - self.remote_evaluators = [ - remote_cls.remote(self.env_creator, self.config, self.logdir) - for _ in range(self.config["num_workers"]) - ] - self.optimizer = AsyncGradientsOptimizer(self.local_evaluator, - self.remote_evaluators, - self.config["optimizer"]) - - def _train(self): - self.optimizer.step() - metric_lists = [ - re.get_metrics.remote() for re in self.remote_evaluators - ] - total_samples = 0 - total_loss = 0 - for metrics in metric_lists: - for m in ray.get(metrics): - total_samples += m["num_samples"] - total_loss += m["loss"] - result = dict( - mean_loss=total_loss / total_samples, - timesteps_this_iter=total_samples, - ) - return result - - def compute_action(self, observation): - action, info = self.local_evaluator.policy.compute(observation) - return action diff --git a/python/ray/rllib/agents/bc/bc_evaluator.py b/python/ray/rllib/agents/bc/bc_evaluator.py deleted file mode 100644 index 4726b4a3cf176..0000000000000 --- a/python/ray/rllib/agents/bc/bc_evaluator.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import pickle -from six.moves import queue - -import ray -from ray.rllib.agents.bc.experience_dataset import ExperienceDataset -from ray.rllib.agents.bc.policy import BCPolicy -from ray.rllib.evaluation.interface import EvaluatorInterface -from ray.rllib.models import ModelCatalog - - -class BCEvaluator(EvaluatorInterface): - def __init__(self, env_creator, config, logdir): - env = ModelCatalog.get_preprocessor_as_wrapper( - env_creator(config["env_config"]), config["model"]) - self.dataset = ExperienceDataset(config["dataset_path"]) - self.policy = BCPolicy(env.observation_space, env.action_space, config) - self.config = config - self.logdir = logdir - self.metrics_queue = queue.Queue() - - def sample(self): - return self.dataset.sample(self.config["batch_size"]) - - def compute_gradients(self, samples): - gradient, info = self.policy.compute_gradients(samples) - self.metrics_queue.put({ - "num_samples": info["num_samples"], - "loss": info["loss"] - }) - return gradient, {} - - def apply_gradients(self, grads): - 
self.policy.apply_gradients(grads) - - def get_weights(self): - return self.policy.get_weights() - - def set_weights(self, params): - self.policy.set_weights(params) - - def save(self): - weights = self.get_weights() - return pickle.dumps({"weights": weights}) - - def restore(self, objs): - objs = pickle.loads(objs) - self.set_weights(objs["weights"]) - - def get_metrics(self): - completed = [] - while True: - try: - completed.append(self.metrics_queue.get_nowait()) - except queue.Empty: - break - return completed - - -RemoteBCEvaluator = ray.remote(BCEvaluator) -GPURemoteBCEvaluator = ray.remote(num_gpus=1)(BCEvaluator) diff --git a/python/ray/rllib/agents/bc/experience_dataset.py b/python/ray/rllib/agents/bc/experience_dataset.py deleted file mode 100644 index d082841842698..0000000000000 --- a/python/ray/rllib/agents/bc/experience_dataset.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import itertools -import pickle - -import numpy as np - - -class ExperienceDataset(object): - def __init__(self, dataset_path): - """Create dataset of experience to imitate. - - Parameters - ---------- - dataset_path: - Path of file containing the database as pickled list of trajectories, - each trajectory being a list of steps, - each step containing the observation and action as its first two - elements. - The file must be available on each machine used by a BCEvaluator. - """ - self._dataset = list( - itertools.chain.from_iterable( - pickle.load(open(dataset_path, "rb")))) - - def sample(self, batch_size): - indexes = np.random.choice(len(self._dataset), batch_size) - samples = { - 'observations': [self._dataset[i][0] for i in indexes], - 'actions': [self._dataset[i][1] for i in indexes] - } - return samples diff --git a/python/ray/rllib/agents/bc/policy.py b/python/ray/rllib/agents/bc/policy.py deleted file mode 100644 index a504e3ec64ff8..0000000000000 --- a/python/ray/rllib/agents/bc/policy.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -import gym - -import ray -from ray.rllib.models.catalog import ModelCatalog - - -class BCPolicy(object): - def __init__(self, obs_space, action_space, config): - self.local_steps = 0 - self.config = config - self.summarize = config.get("summarize") - self._setup_graph(obs_space, action_space) - self.setup_loss(action_space) - self.setup_gradients() - self.initialize() - - def _setup_graph(self, obs_space, ac_space): - self.x = tf.placeholder(tf.float32, [None] + list(obs_space.shape)) - dist_class, self.logit_dim = ModelCatalog.get_action_dist( - ac_space, self.config["model"]) - self._model = ModelCatalog.get_model(self.x, self.logit_dim, - self.config["model"]) - self.logits = self._model.outputs - self.curr_dist = dist_class(self.logits) - self.sample = self.curr_dist.sample() - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) - - def setup_loss(self, action_space): - if isinstance(action_space, gym.spaces.Box): - self.ac = tf.placeholder( - tf.float32, [None] + list(action_space.shape), name="ac") - elif isinstance(action_space, gym.spaces.Discrete): - self.ac = tf.placeholder(tf.int64, [None], name="ac") - else: - raise NotImplementedError("action space" + - str(type(action_space)) + - "currently not supported") - log_prob = self.curr_dist.logp(self.ac) - self.pi_loss = 
-tf.reduce_sum(log_prob) - self.loss = self.pi_loss - - def setup_gradients(self): - grads = tf.gradients(self.loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) - grads_and_vars = list(zip(self.grads, self.var_list)) - opt = tf.train.AdamOptimizer(self.config["lr"]) - self._apply_gradients = opt.apply_gradients(grads_and_vars) - - def initialize(self): - if self.summarize: - bs = tf.to_float(tf.shape(self.x)[0]) - tf.summary.scalar("model/policy_loss", self.pi_loss / bs) - tf.summary.scalar("model/grad_gnorm", tf.global_norm(self.grads)) - tf.summary.scalar("model/var_gnorm", tf.global_norm(self.var_list)) - self.summary_op = tf.summary.merge_all() - - # TODO(rliaw): Can consider exposing these parameters - self.sess = tf.Session( - graph=self.g, - config=tf.ConfigProto( - intra_op_parallelism_threads=1, - inter_op_parallelism_threads=2, - gpu_options=tf.GPUOptions(allow_growth=True))) - self.variables = ray.experimental.TensorFlowVariables( - self.loss, self.sess) - self.sess.run(tf.global_variables_initializer()) - - def compute_gradients(self, samples): - info = {} - feed_dict = { - self.x: samples["observations"], - self.ac: samples["actions"] - } - self.grads = [g for g in self.grads if g is not None] - self.local_steps += 1 - if self.summarize: - loss, grad, summ = self.sess.run( - [self.loss, self.grads, self.summary_op], feed_dict=feed_dict) - info["summary"] = summ - else: - loss, grad = self.sess.run( - [self.loss, self.grads], feed_dict=feed_dict) - info["num_samples"] = len(samples) - info["loss"] = loss - return grad, info - - def apply_gradients(self, grads): - feed_dict = {self.grads[i]: grads[i] for i in range(len(grads))} - self.sess.run(self._apply_gradients, feed_dict=feed_dict) - - def get_weights(self): - weights = self.variables.get_weights() - return weights - - def set_weights(self, weights): - self.variables.set_weights(weights) - - def compute(self, ob, *args): - action = self.sess.run(self.sample, {self.x: [ob]}) - return action, None diff --git a/python/ray/rllib/agents/ddpg/apex.py b/python/ray/rllib/agents/ddpg/apex.py index c2276d0a9a556..6b3465013da36 100644 --- a/python/ray/rllib/agents/ddpg/apex.py +++ b/python/ray/rllib/agents/ddpg/apex.py @@ -3,11 +3,11 @@ from __future__ import print_function from ray.rllib.agents.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG as DDPG_CONFIG +from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts -from ray.tune.trial import Resources APEX_DDPG_DEFAULT_CONFIG = merge_dicts( - DDPG_CONFIG, + DDPG_CONFIG, # see also the options in ddpg.py, which are also supported { "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( @@ -17,7 +17,7 @@ "debug": False }), "n_step": 3, - "gpu": False, + "num_gpus": 0, "num_workers": 32, "buffer_size": 2000000, "learning_starts": 50000, @@ -43,15 +43,7 @@ class ApexDDPGAgent(DDPGAgent): _agent_name = "APEX_DDPG" _default_config = APEX_DDPG_DEFAULT_CONFIG - @classmethod - def default_resource_request(cls, config): - cf = merge_dicts(cls._default_config, config) - return Resources( - cpu=1 + cf["optimizer"]["num_replay_buffer_shards"], - gpu=cf["gpu"] and cf["gpu_fraction"] or 0, - extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], - extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) - + @override(DDPGAgent) def update_target_if_needed(self): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ diff --git 
a/python/ray/rllib/agents/ddpg/ddpg.py b/python/ray/rllib/agents/ddpg/ddpg.py index b475e297a2472..ca0e8087f4644 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -5,6 +5,7 @@ from ray.rllib.agents.agent import with_common_config from ray.rllib.agents.dqn.dqn import DQNAgent from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph +from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule OPTIMIZER_SHARED_CONFIGS = [ @@ -13,7 +14,25 @@ "train_batch_size", "learning_starts" ] +# yapf: disable +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ + # === Twin Delayed DDPG (TD3) and Soft Actor-Critic (SAC) tricks === + # TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html + # twin Q-net + "twin_q": False, + # delayed policy update + "policy_delay": 1, + # target policy smoothing + # this also forces the use of gaussian instead of OU noise for exploration + "smooth_target_policy": False, + # gaussian stddev of act noise + "act_noise": 0.1, + # gaussian stddev of target noise + "target_noise": 0.2, + # target noise limit (bound) + "noise_clip": 0.5, + # === Model === # Hidden layer sizes of the policy network "actor_hiddens": [64, 64], @@ -65,9 +84,11 @@ "compress_observations": False, # === Optimization === - # Learning rate for adam optimizer - "actor_lr": 1e-4, - "critic_lr": 1e-3, + # Learning rate for adam optimizer. + # Instead of using two optimizers, we use two different loss coefficients + "lr": 1e-3, + "actor_loss_coeff": 0.1, + "critic_loss_coeff": 1.0, # If True, use huber loss instead of squared loss for critic network # Conventionally, no need to clip gradients if using a huber loss "use_huber": False, @@ -88,16 +109,10 @@ "train_batch_size": 256, # === Parallelism === - # Whether to use a GPU for local optimization. - "gpu": False, # Number of workers for collecting samples with. This only makes sense # to increase if your environment is particularly slow to sample, or if # you"re using the Async or Ape-X optimizers. "num_workers": 0, - # Whether to allocate GPUs for workers (if > 0). - "num_gpus_per_worker": 0, - # Whether to allocate CPUs for workers (if > 0). - "num_cpus_per_worker": 1, # Optimizer class to use. "optimizer_class": "SyncReplayOptimizer", # Whether to use a distribution of epsilons across workers for exploration. 
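# Note: illustration only, not part of the patch. The "smooth_target_policy",
# "target_noise" and "noise_clip" options above implement TD3-style target
# policy smoothing; a minimal numpy sketch of the idea for a Box action
# space with bounds [low, high]:
import numpy as np

def smoothed_target_action(target_action, low, high,
                           target_noise=0.2, noise_clip=0.5):
    # Perturb the target action with clipped Gaussian noise, then clip the
    # result back into the valid action range.
    noise = np.clip(
        np.random.normal(0.0, target_noise, target_action.shape),
        -noise_clip, noise_clip)
    return np.clip(target_action + noise, low, high)

# e.g. smoothed_target_action(np.array([0.3, -0.9]), low=-1.0, high=1.0)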
@@ -107,6 +122,8 @@ # Prevent iterations from going lower than this time span "min_iter_time_s": 1, }) +# __sphinx_doc_end__ +# yapf: enable class DDPGAgent(DQNAgent): @@ -115,14 +132,22 @@ class DDPGAgent(DQNAgent): _default_config = DEFAULT_CONFIG _policy_graph = DDPGPolicyGraph + @override(DQNAgent) def _make_exploration_schedule(self, worker_index): # Override DQN's schedule to take into account `noise_scale` if self.config["per_worker_exploration"]: assert self.config["num_workers"] > 1, \ "This requires multiple workers" - exponent = ( - 1 + worker_index / float(self.config["num_workers"] - 1) * 7) - return ConstantSchedule(self.config["noise_scale"] * 0.4**exponent) + if worker_index >= 0: + exponent = ( + 1 + + worker_index / float(self.config["num_workers"] - 1) * 7) + return ConstantSchedule( + self.config["noise_scale"] * 0.4**exponent) + else: + # local ev should have zero exploration so that eval rollouts + # run properly + return ConstantSchedule(0.0) else: return LinearSchedule( schedule_timesteps=int(self.config["exploration_fraction"] * diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py index a6f26885fe308..b8b625734793d 100644 --- a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py +++ b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py @@ -11,7 +11,9 @@ from ray.rllib.agents.dqn.dqn_policy_graph import _huber_loss, \ _minimize_and_clip, _scope_vars, _postprocess_dqn from ray.rllib.models import ModelCatalog +from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph A_SCOPE = "a_func" @@ -19,6 +21,8 @@ P_TARGET_SCOPE = "target_p_func" Q_SCOPE = "q_func" Q_TARGET_SCOPE = "target_q_func" +TWIN_Q_SCOPE = "twin_q_func" +TWIN_Q_TARGET_SCOPE = "twin_target_q_func" class PNetwork(object): @@ -50,24 +54,47 @@ def __init__(self, stochastic, eps, theta=0.15, - sigma=0.2): + sigma=0.2, + use_gaussian_noise=False, + act_noise=0.1, + is_target=False, + target_noise=0.2, + noise_clip=0.5): # shape is [None, dim_action] deterministic_actions = ( (high_action - low_action) * p_values + low_action) - exploration_sample = tf.get_variable( - name="ornstein_uhlenbeck", - dtype=tf.float32, - initializer=low_action.size * [.0], - trainable=False) - normal_sample = tf.random_normal( - shape=[low_action.size], mean=0.0, stddev=1.0) - exploration_value = tf.assign_add( - exploration_sample, - theta * (.0 - exploration_sample) + sigma * normal_sample) - stochastic_actions = deterministic_actions + eps * ( - high_action - low_action) * exploration_value + if use_gaussian_noise: + if is_target: + normal_sample = tf.random_normal( + tf.shape(deterministic_actions), stddev=target_noise) + normal_sample = tf.clip_by_value(normal_sample, -noise_clip, + noise_clip) + stochastic_actions = tf.clip_by_value( + deterministic_actions + normal_sample, low_action, + high_action) + else: + normal_sample = tf.random_normal( + tf.shape(deterministic_actions), stddev=act_noise) + stochastic_actions = tf.clip_by_value( + deterministic_actions + normal_sample, low_action, + high_action) + else: + exploration_sample = tf.get_variable( + name="ornstein_uhlenbeck", + dtype=tf.float32, + initializer=low_action.size * [.0], + trainable=False) + normal_sample = tf.random_normal( + shape=[low_action.size], mean=0.0, stddev=1.0) + exploration_value = tf.assign_add( + exploration_sample, + 
theta * (.0 - exploration_sample) + sigma * normal_sample)
+            stochastic_actions = tf.clip_by_value(
+                deterministic_actions +
+                eps * (high_action - low_action) * exploration_value,
+                low_action, high_action)

         self.actions = tf.cond(stochastic, lambda: stochastic_actions,
                                lambda: deterministic_actions)
@@ -86,6 +113,7 @@ def __init__(self,
             q_out, num_outputs=hidden, activation_fn=activation)
         self.value = layers.fully_connected(
             q_out, num_outputs=1, activation_fn=None)
+        self.model = model


 class ActorCriticLoss(object):
@@ -96,12 +124,21 @@ def __init__(self,
                  importance_weights,
                  rewards,
                  done_mask,
+                 twin_q_t,
+                 twin_q_tp1,
+                 actor_loss_coeff=0.1,
+                 critic_loss_coeff=1.0,
                  gamma=0.99,
                  n_step=1,
                  use_huber=False,
-                 huber_threshold=1.0):
+                 huber_threshold=1.0,
+                 twin_q=False,
+                 policy_delay=1):

         q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
+        if twin_q:
+            twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
+            q_tp1 = tf.minimum(q_tp1, twin_q_tp1)

         q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
         q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
@@ -110,16 +147,36 @@ def __init__(self,
         q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

         # compute the error (potentially clipped)
-        self.td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
-        if use_huber:
-            errors = _huber_loss(self.td_error, huber_threshold)
+        if twin_q:
+            td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
+            twin_td_error = twin_q_t_selected - tf.stop_gradient(
+                q_t_selected_target)
+            self.td_error = td_error + twin_td_error
+            if use_huber:
+                errors = _huber_loss(td_error, huber_threshold) + _huber_loss(
+                    twin_td_error, huber_threshold)
+            else:
+                errors = 0.5 * tf.square(td_error) + 0.5 * tf.square(
+                    twin_td_error)
         else:
-            errors = 0.5 * tf.square(self.td_error)
-
-        self.critic_loss = tf.reduce_mean(importance_weights * errors)
+            self.td_error = (
+                q_t_selected - tf.stop_gradient(q_t_selected_target))
+            if use_huber:
+                errors = _huber_loss(self.td_error, huber_threshold)
+            else:
+                errors = 0.5 * tf.square(self.td_error)
+
+        self.critic_loss = critic_loss_coeff * tf.reduce_mean(
+            importance_weights * errors)
+
+        # for policy gradient, update policy net one time vs. 
+ # update critic net `policy_delay` time(s) + global_step = tf.train.get_or_create_global_step() + policy_delay_mask = tf.to_float( + tf.equal(tf.mod(global_step, policy_delay), 0)) + self.actor_loss = (-1.0 * actor_loss_coeff * policy_delay_mask * + tf.reduce_mean(q_tp0)) - # for policy gradient - self.actor_loss = -1.0 * tf.reduce_mean(q_tp0) self.total_loss = self.actor_loss + self.critic_loss @@ -136,20 +193,22 @@ def __init__(self, observation_space, action_space, config): self.dim_actions = action_space.shape[0] self.low_action = action_space.low self.high_action = action_space.high - self.actor_optimizer = tf.train.AdamOptimizer( - learning_rate=config["actor_lr"]) - self.critic_optimizer = tf.train.AdamOptimizer( - learning_rate=config["critic_lr"]) + + # create global step for counting the number of update operations + self.global_step = tf.train.get_or_create_global_step() # Action inputs self.stochastic = tf.placeholder(tf.bool, (), name="stochastic") self.eps = tf.placeholder(tf.float32, (), name="eps") self.cur_observations = tf.placeholder( - tf.float32, shape=(None, ) + observation_space.shape) + tf.float32, + shape=(None, ) + observation_space.shape, + name="cur_obs") # Actor: P (policy) network with tf.variable_scope(P_SCOPE) as scope: - p_values = self._build_p_network(self.cur_observations) + p_values = self._build_p_network(self.cur_observations, + observation_space) self.p_func_vars = _scope_vars(scope.name) # Action outputs @@ -157,10 +216,13 @@ def __init__(self, observation_space, action_space, config): self.output_actions = self._build_action_network( p_values, self.stochastic, self.eps) - with tf.variable_scope(A_SCOPE, reuse=True): - exploration_sample = tf.get_variable(name="ornstein_uhlenbeck") - self.reset_noise_op = tf.assign(exploration_sample, - self.dim_actions * [.0]) + if self.config["smooth_target_policy"]: + self.reset_noise_op = tf.no_op() + else: + with tf.variable_scope(A_SCOPE, reuse=True): + exploration_sample = tf.get_variable(name="ornstein_uhlenbeck") + self.reset_noise_op = tf.assign(exploration_sample, + self.dim_actions * [.0]) # Replay inputs self.obs_t = tf.placeholder( @@ -178,37 +240,63 @@ def __init__(self, observation_space, action_space, config): # p network evaluation with tf.variable_scope(P_SCOPE, reuse=True) as scope: - self.p_t = self._build_p_network(self.obs_t) + prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) + self.p_t = self._build_p_network(self.obs_t, observation_space) + p_batchnorm_update_ops = list( + set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - + prev_update_ops) # target p network evaluation with tf.variable_scope(P_TARGET_SCOPE) as scope: - p_tp1 = self._build_p_network(self.obs_tp1) + p_tp1 = self._build_p_network(self.obs_tp1, observation_space) target_p_func_vars = _scope_vars(scope.name) # Action outputs with tf.variable_scope(A_SCOPE, reuse=True): - deterministic_flag = tf.constant(value=False, dtype=tf.bool) - zero_eps = tf.constant(value=.0, dtype=tf.float32) output_actions = self._build_action_network( - self.p_t, deterministic_flag, zero_eps) - + self.p_t, + stochastic=tf.constant(value=False, dtype=tf.bool), + eps=.0) output_actions_estimated = self._build_action_network( - p_tp1, deterministic_flag, zero_eps) + p_tp1, + stochastic=tf.constant( + value=self.config["smooth_target_policy"], dtype=tf.bool), + eps=.0, + is_target=True) # q network evaluation + prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) with tf.variable_scope(Q_SCOPE) as scope: - q_t = 
self._build_q_network(self.obs_t, self.act_t) + q_t, model = self._build_q_network(self.obs_t, observation_space, + self.act_t) self.q_func_vars = _scope_vars(scope.name) with tf.variable_scope(Q_SCOPE, reuse=True): - q_tp0 = self._build_q_network(self.obs_t, output_actions) + q_tp0, _ = self._build_q_network(self.obs_t, observation_space, + output_actions) + if self.config["twin_q"]: + with tf.variable_scope(TWIN_Q_SCOPE) as scope: + twin_q_t, twin_model = self._build_q_network( + self.obs_t, observation_space, self.act_t) + self.twin_q_func_vars = _scope_vars(scope.name) + q_batchnorm_update_ops = list( + set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) # target q network evalution with tf.variable_scope(Q_TARGET_SCOPE) as scope: - q_tp1 = self._build_q_network(self.obs_tp1, - output_actions_estimated) + q_tp1, _ = self._build_q_network(self.obs_tp1, observation_space, + output_actions_estimated) target_q_func_vars = _scope_vars(scope.name) - - self.loss = self._build_actor_critic_loss(q_t, q_tp1, q_tp0) + if self.config["twin_q"]: + with tf.variable_scope(TWIN_Q_TARGET_SCOPE) as scope: + twin_q_tp1, _ = self._build_q_network( + self.obs_tp1, observation_space, output_actions_estimated) + twin_target_q_func_vars = _scope_vars(scope.name) + + if self.config["twin_q"]: + self.loss = self._build_actor_critic_loss( + q_t, q_tp1, q_tp0, twin_q_t=twin_q_t, twin_q_tp1=twin_q_tp1) + else: + self.loss = self._build_actor_critic_loss(q_t, q_tp1, q_tp0) if config["l2_reg"] is not None: for var in self.p_func_vars: @@ -219,6 +307,11 @@ def __init__(self, observation_space, action_space, config): if "bias" not in var.name: self.loss.critic_loss += ( config["l2_reg"] * 0.5 * tf.nn.l2_loss(var)) + if self.config["twin_q"]: + for var in self.twin_q_func_vars: + if "bias" not in var.name: + self.loss.critic_loss += ( + config["l2_reg"] * 0.5 * tf.nn.l2_loss(var)) # update_target_fn will be called periodically to copy Q network to # target Q network @@ -231,6 +324,13 @@ def __init__(self, observation_space, action_space, config): update_target_expr.append( var_target.assign(self.tau * var + (1.0 - self.tau) * var_target)) + if self.config["twin_q"]: + for var, var_target in zip( + sorted(self.twin_q_func_vars, key=lambda v: v.name), + sorted(twin_target_q_func_vars, key=lambda v: v.name)): + update_target_expr.append( + var_target.assign(self.tau * var + + (1.0 - self.tau) * var_target)) for var, var_target in zip( sorted(self.p_func_vars, key=lambda v: v.name), sorted(target_p_func_vars, key=lambda v: v.name)): @@ -255,8 +355,9 @@ def __init__(self, observation_space, action_space, config): self.sess, obs_input=self.cur_observations, action_sampler=self.output_actions, - loss=self.loss.total_loss, - loss_inputs=self.loss_inputs) + loss=model.loss() + self.loss.total_loss, + loss_inputs=self.loss_inputs, + update_ops=q_batchnorm_update_ops + p_batchnorm_update_ops) self.sess.run(tf.global_variables_initializer()) # Note that this encompasses both the policy and Q-value networks and @@ -267,46 +368,31 @@ def __init__(self, observation_space, action_space, config): # Hard initial update self.update_target(tau=1.0) - def _build_q_network(self, obs, actions): - return QNetwork( - ModelCatalog.get_model(obs, 1, self.config["model"]), actions, - self.config["critic_hiddens"], - self.config["critic_hidden_activation"]).value - - def _build_p_network(self, obs): - return PNetwork( - ModelCatalog.get_model(obs, 1, self.config["model"]), - self.dim_actions, self.config["actor_hiddens"], - 
self.config["actor_hidden_activation"]).action_scores - - def _build_action_network(self, p_values, stochastic, eps): - return ActionNetwork(p_values, self.low_action, self.high_action, - stochastic, eps, self.config["exploration_theta"], - self.config["exploration_sigma"]).actions - - def _build_actor_critic_loss(self, q_t, q_tp1, q_tp0): - return ActorCriticLoss( - q_t, q_tp1, q_tp0, self.importance_weights, self.rew_t, - self.done_mask, self.config["gamma"], self.config["n_step"], - self.config["use_huber"], self.config["huber_threshold"]) + @override(TFPolicyGraph) + def optimizer(self): + return tf.train.AdamOptimizer(learning_rate=self.config["lr"]) + @override(TFPolicyGraph) def gradients(self, optimizer): if self.config["grad_norm_clipping"] is not None: actor_grads_and_vars = _minimize_and_clip( - self.actor_optimizer, + optimizer, self.loss.actor_loss, var_list=self.p_func_vars, clip_val=self.config["grad_norm_clipping"]) critic_grads_and_vars = _minimize_and_clip( - self.critic_optimizer, + optimizer, self.loss.critic_loss, - var_list=self.q_func_vars, + var_list=self.q_func_vars + self.twin_q_func_vars + if self.config["twin_q"] else self.q_func_vars, clip_val=self.config["grad_norm_clipping"]) else: - actor_grads_and_vars = self.actor_optimizer.compute_gradients( + actor_grads_and_vars = optimizer.compute_gradients( self.loss.actor_loss, var_list=self.p_func_vars) - critic_grads_and_vars = self.critic_optimizer.compute_gradients( - self.loss.critic_loss, var_list=self.q_func_vars) + critic_grads_and_vars = optimizer.compute_gradients( + self.loss.critic_loss, + var_list=self.q_func_vars + self.twin_q_func_vars + if self.config["twin_q"] else self.q_func_vars) actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars if g is not None] critic_grads_and_vars = [(g, v) for (g, v) in critic_grads_and_vars @@ -314,20 +400,85 @@ def gradients(self, optimizer): grads_and_vars = actor_grads_and_vars + critic_grads_and_vars return grads_and_vars + @override(TFPolicyGraph) def extra_compute_action_feed_dict(self): return { self.stochastic: True, self.eps: self.cur_epsilon, } + @override(TFPolicyGraph) def extra_compute_grad_fetches(self): return { "td_error": self.loss.td_error, } - def postprocess_trajectory(self, sample_batch, other_agent_batches=None): + @override(PolicyGraph) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): return _postprocess_dqn(self, sample_batch) + @override(TFPolicyGraph) + def get_weights(self): + return self.variables.get_weights() + + @override(TFPolicyGraph) + def set_weights(self, weights): + self.variables.set_weights(weights) + + @override(PolicyGraph) + def get_state(self): + return [TFPolicyGraph.get_state(self), self.cur_epsilon] + + @override(PolicyGraph) + def set_state(self, state): + TFPolicyGraph.set_state(self, state[0]) + self.set_epsilon(state[1]) + + def _build_q_network(self, obs, obs_space, actions): + q_net = QNetwork( + ModelCatalog.get_model({ + "obs": obs, + "is_training": self._get_is_training_placeholder(), + }, obs_space, 1, self.config["model"]), actions, + self.config["critic_hiddens"], + self.config["critic_hidden_activation"]) + return q_net.value, q_net.model + + def _build_p_network(self, obs, obs_space): + return PNetwork( + ModelCatalog.get_model({ + "obs": obs, + "is_training": self._get_is_training_placeholder(), + }, obs_space, 1, self.config["model"]), self.dim_actions, + self.config["actor_hiddens"], + self.config["actor_hidden_activation"]).action_scores + + def 
_build_action_network(self, p_values, stochastic, eps, + is_target=False): + return ActionNetwork( + p_values, self.low_action, self.high_action, stochastic, eps, + self.config["exploration_theta"], self.config["exploration_sigma"], + self.config["smooth_target_policy"], self.config["act_noise"], + is_target, self.config["target_noise"], + self.config["noise_clip"]).actions + + def _build_actor_critic_loss(self, + q_t, + q_tp1, + q_tp0, + twin_q_t=None, + twin_q_tp1=None): + return ActorCriticLoss( + q_t, q_tp1, q_tp0, self.importance_weights, self.rew_t, + self.done_mask, twin_q_t, twin_q_tp1, + self.config["actor_loss_coeff"], self.config["critic_loss_coeff"], + self.config["gamma"], self.config["n_step"], + self.config["use_huber"], self.config["huber_threshold"], + self.config["twin_q"]) + def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights): td_err = self.sess.run( @@ -353,16 +504,3 @@ def update_target(self, tau=None): def set_epsilon(self, epsilon): self.cur_epsilon = epsilon - - def get_weights(self): - return self.variables.get_weights() - - def set_weights(self, weights): - self.variables.set_weights(weights) - - def get_state(self): - return [TFPolicyGraph.get_state(self), self.cur_epsilon] - - def set_state(self, state): - TFPolicyGraph.set_state(self, state[0]) - self.set_epsilon(state[1]) diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index e6058b41f9af3..c9b15e0eca792 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -4,10 +4,12 @@ from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG as DQN_CONFIG from ray.rllib.utils import merge_dicts -from ray.tune.trial import Resources +from ray.rllib.utils.annotations import override +# yapf: disable +# __sphinx_doc_begin__ APEX_DEFAULT_CONFIG = merge_dicts( - DQN_CONFIG, + DQN_CONFIG, # see also the options in dqn.py, which are also supported { "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( @@ -17,7 +19,7 @@ "debug": False }), "n_step": 3, - "gpu": True, + "num_gpus": 1, "num_workers": 32, "buffer_size": 2000000, "learning_starts": 50000, @@ -30,6 +32,8 @@ "min_iter_time_s": 30, }, ) +# __sphinx_doc_end__ +# yapf: enable class ApexAgent(DQNAgent): @@ -42,15 +46,7 @@ class ApexAgent(DQNAgent): _agent_name = "APEX" _default_config = APEX_DEFAULT_CONFIG - @classmethod - def default_resource_request(cls, config): - cf = merge_dicts(cls._default_config, config) - return Resources( - cpu=1 + cf["optimizer"]["num_replay_buffer_shards"], - gpu=cf["gpu"] and cf["gpu_fraction"] or 0, - extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], - extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) - + @override(DQNAgent) def update_target_if_needed(self): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ diff --git a/python/ray/rllib/agents/dqn/common/wrappers.py b/python/ray/rllib/agents/dqn/common/wrappers.py deleted file mode 100644 index eb6a6c0d5b5c2..0000000000000 --- a/python/ray/rllib/agents/dqn/common/wrappers.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from ray.rllib.models import ModelCatalog -from ray.rllib.env.atari_wrappers import wrap_deepmind - - -def wrap_dqn(env, options): - """Apply a common set of wrappers for DQN.""" - - is_atari = hasattr(env.unwrapped, "ale") - - # Override atari default to use 
the deepmind wrappers. - # TODO(ekl) this logic should be pushed to the catalog. - if is_atari and "custom_preprocessor" not in options: - return wrap_deepmind(env, dim=options.get("dim", 84)) - - return ModelCatalog.get_preprocessor_as_wrapper(env, options) diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index c945cdbc9fe8e..10e3edd48bd82 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -7,10 +7,8 @@ from ray.rllib import optimizers from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph -from ray.rllib.evaluation.metrics import collect_metrics -from ray.rllib.utils import merge_dicts +from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule -from ray.tune.trial import Resources OPTIMIZER_SHARED_CONFIGS = [ "buffer_size", "prioritized_replay", "prioritized_replay_alpha", @@ -20,6 +18,8 @@ "learning_starts" ] +# yapf: disable +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # === Model === # Number of atoms for representing the distribution of return. When @@ -40,8 +40,6 @@ "hiddens": [256], # N-step Q learning "n_step": 1, - # Whether to use rllib or deepmind preprocessors - "preprocessor_pref": "deepmind", # === Exploration === # Max num timesteps for annealing schedules. Exploration is annealed from @@ -96,16 +94,10 @@ "train_batch_size": 32, # === Parallelism === - # Whether to use a GPU for local optimization. - "gpu": False, # Number of workers for collecting samples with. This only makes sense # to increase if your environment is particularly slow to sample, or if # you"re using the Async or Ape-X optimizers. "num_workers": 0, - # Whether to allocate GPUs for workers (if > 0). - "num_gpus_per_worker": 0, - # Whether to allocate CPUs for workers (if > 0). - "num_cpus_per_worker": 1, # Optimizer class to use. "optimizer_class": "SyncReplayOptimizer", # Whether to use a distribution of epsilons across workers for exploration. 
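# Note: illustration only, not part of the patch. When "n_step" above is
# greater than 1, sampled transitions are rewritten so that reward[i] folds
# in the next n-1 discounted rewards and new_obs[i] points n steps ahead.
# A rough sketch of that rewrite (ignoring episode boundaries):
def nstep_fold(n_step, gamma, rewards, new_obs):
    rewards, new_obs = list(rewards), list(new_obs)
    for i in range(len(rewards) - n_step + 1):
        for j in range(1, n_step):
            rewards[i] += gamma**j * rewards[i + j]
        new_obs[i] = new_obs[i + n_step - 1]
    return rewards, new_obs

# nstep_fold(3, 0.9, [1, 1, 1, 1, 1], ["o1", "o2", "o3", "o4", "o5"])
# folds the first reward to 1 + 0.9 + 0.81 = 2.71 and moves its new_obs
# to "o3".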
@@ -115,6 +107,8 @@ # Prevent iterations from going lower than this time span "min_iter_time_s": 1, }) +# __sphinx_doc_end__ +# yapf: enable class DQNAgent(Agent): @@ -124,22 +118,14 @@ class DQNAgent(Agent): _default_config = DEFAULT_CONFIG _policy_graph = DQNPolicyGraph - @classmethod - def default_resource_request(cls, config): - cf = merge_dicts(cls._default_config, config) - return Resources( - cpu=1, - gpu=cf["gpu"] and cf["gpu_fraction"] or 0, - extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], - extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) - + @override(Agent) def _init(self): # Update effective batch size to include n-step - adjusted_batch_size = ( - self.config["sample_batch_size"] + self.config["n_step"] - 1) + adjusted_batch_size = max(self.config["sample_batch_size"], + self.config["n_step"]) self.config["sample_batch_size"] = adjusted_batch_size - self.exploration0 = self._make_exploration_schedule(0) + self.exploration0 = self._make_exploration_schedule(-1) self.explorations = [ self._make_exploration_schedule(i) for i in range(self.config["num_workers"]) @@ -159,12 +145,9 @@ def _init(self): self.env_creator, self._policy_graph) def create_remote_evaluators(): - return self.make_remote_evaluators( - self.env_creator, self._policy_graph, - self.config["num_workers"], { - "num_cpus": self.config["num_cpus_per_worker"], - "num_gpus": self.config["num_gpus_per_worker"] - }) + return self.make_remote_evaluators(self.env_creator, + self._policy_graph, + self.config["num_workers"]) if self.config["optimizer_class"] != "AsyncReplayOptimizer": self.remote_evaluators = create_remote_evaluators() @@ -178,47 +161,16 @@ def create_remote_evaluators(): # Create the remote evaluators *after* the replay actors if self.remote_evaluators is None: self.remote_evaluators = create_remote_evaluators() - self.optimizer.set_evaluators(self.remote_evaluators) + self.optimizer._set_evaluators(self.remote_evaluators) self.last_target_update_ts = 0 self.num_target_updates = 0 - def _make_exploration_schedule(self, worker_index): - # Use either a different `eps` per worker, or a linear schedule. 
- if self.config["per_worker_exploration"]: - assert self.config["num_workers"] > 1, \ - "This requires multiple workers" - exponent = ( - 1 + worker_index / float(self.config["num_workers"] - 1) * 7) - return ConstantSchedule(0.4**exponent) - return LinearSchedule( - schedule_timesteps=int(self.config["exploration_fraction"] * - self.config["schedule_max_timesteps"]), - initial_p=1.0, - final_p=self.config["exploration_final_eps"]) - - @property - def global_timestep(self): - return self.optimizer.num_steps_sampled - - def update_target_if_needed(self): - if self.global_timestep - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.local_evaluator.foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.global_timestep - self.num_target_updates += 1 - + @override(Agent) def _train(self): start_timestep = self.global_timestep - start = time.time() - while (self.global_timestep - start_timestep < - self.config["timesteps_per_iteration"] - ) or time.time() - start < self.config["min_iter_time_s"]: - self.optimizer.step() - self.update_target_if_needed() - + # Update worker explorations exp_vals = [self.exploration0.value(self.global_timestep)] self.local_evaluator.foreach_trainable_policy( lambda p, _: p.set_epsilon(exp_vals[0])) @@ -228,14 +180,23 @@ def _train(self): lambda p, _: p.set_epsilon(exp_val)) exp_vals.append(exp_val) + # Do optimization steps + start = time.time() + while (self.global_timestep - start_timestep < + self.config["timesteps_per_iteration"] + ) or time.time() - start < self.config["min_iter_time_s"]: + self.optimizer.step() + self.update_target_if_needed() + if self.config["per_worker_exploration"]: # Only collect metrics from the third of workers with lowest eps - result = collect_metrics( - self.local_evaluator, - self.remote_evaluators[-len(self.remote_evaluators) // 3:]) + result = self.optimizer.collect_metrics( + timeout_seconds=self.config["collect_metrics_timeout"], + selected_evaluators=self.remote_evaluators[ + -len(self.remote_evaluators) // 3:]) else: - result = collect_metrics(self.local_evaluator, - self.remote_evaluators) + result = self.optimizer.collect_metrics( + timeout_seconds=self.config["collect_metrics_timeout"]) result.update( timesteps_this_iter=self.global_timestep - start_timestep, @@ -246,6 +207,38 @@ def _train(self): }, **self.optimizer.stats())) return result + def update_target_if_needed(self): + if self.global_timestep - self.last_target_update_ts > \ + self.config["target_network_update_freq"]: + self.local_evaluator.foreach_trainable_policy( + lambda p, _: p.update_target()) + self.last_target_update_ts = self.global_timestep + self.num_target_updates += 1 + + @property + def global_timestep(self): + return self.optimizer.num_steps_sampled + + def _make_exploration_schedule(self, worker_index): + # Use either a different `eps` per worker, or a linear schedule. 
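+        # With the per-worker schedule, worker i of N gets a constant eps
+        # of 0.4 ** (1 + 7 * i / (N - 1)); for N = 8 that ranges from 0.4
+        # (worker 0) down to roughly 0.00066 (worker 7). The local
+        # evaluator (worker_index -1) gets eps 0.0 below.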
+ if self.config["per_worker_exploration"]: + assert self.config["num_workers"] > 1, \ + "This requires multiple workers" + if worker_index >= 0: + exponent = ( + 1 + + worker_index / float(self.config["num_workers"] - 1) * 7) + return ConstantSchedule(0.4**exponent) + else: + # local ev should have zero exploration so that eval rollouts + # run properly + return ConstantSchedule(0.0) + return LinearSchedule( + schedule_timesteps=int(self.config["exploration_fraction"] * + self.config["schedule_max_timesteps"]), + initial_p=1.0, + final_p=self.config["exploration_final_eps"]) + def __getstate__(self): state = Agent.__getstate__(self) state.update({ diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index a2e8c8022c7b1..625e577fff164 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -10,7 +10,9 @@ import ray from ray.rllib.models import ModelCatalog from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph Q_SCOPE = "q_func" @@ -28,17 +30,23 @@ def __init__(self, v_min=-10.0, v_max=10.0, sigma0=0.5): + self.model = model with tf.variable_scope("action_value"): - action_out = model.last_layer - for i in range(len(hiddens)): - if use_noisy: - action_out = self.noisy_layer("hidden_%d" % i, action_out, - hiddens[i], sigma0) - else: - action_out = layers.fully_connected( - action_out, - num_outputs=hiddens[i], - activation_fn=tf.nn.relu) + if hiddens: + action_out = model.last_layer + for i in range(len(hiddens)): + if use_noisy: + action_out = self.noisy_layer( + "hidden_%d" % i, action_out, hiddens[i], sigma0) + else: + action_out = layers.fully_connected( + action_out, + num_outputs=hiddens[i], + activation_fn=tf.nn.relu) + else: + # Avoid postprocessing the outputs. This enables custom models + # to be used for parametric action DQN. + action_out = model.outputs if use_noisy: action_scores = self.noisy_layer( "output", @@ -46,11 +54,13 @@ def __init__(self, num_actions * num_atoms, sigma0, non_linear=False) - else: + elif hiddens: action_scores = layers.fully_connected( action_out, num_outputs=num_actions * num_atoms, activation_fn=None) + else: + action_scores = model.outputs if num_atoms > 1: # Distributional Q-learning uses a discrete support z # to represent the action value distribution @@ -106,7 +116,7 @@ def __init__(self, self.logits = support_logits_per_action self.dist = support_prob_per_action else: - action_scores_mean = tf.reduce_mean(action_scores, 1) + action_scores_mean = _reduce_mean_ignore_inf(action_scores, 1) action_scores_centered = action_scores - tf.expand_dims( action_scores_mean, 1) self.value = state_score + action_scores_centered @@ -175,11 +185,15 @@ class QValuePolicy(object): def __init__(self, q_values, observations, num_actions, stochastic, eps): deterministic_actions = tf.argmax(q_values, axis=1) batch_size = tf.shape(observations)[0] - random_actions = tf.random_uniform( - tf.stack([batch_size]), - minval=0, - maxval=num_actions, - dtype=tf.int64) + + # Special case masked out actions (q_value ~= -inf) so that we don't + # even consider them for exploration. 
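+        # (Invalid actions keep a logit of tf.float32.min while valid ones
+        # share a logit of 1, so tf.multinomial samples uniformly over the
+        # valid actions and effectively never picks a masked one.)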
+ random_valid_action_logits = tf.where( + tf.equal(q_values, tf.float32.min), + tf.ones_like(q_values) * tf.float32.min, tf.ones_like(q_values)) + random_actions = tf.squeeze( + tf.multinomial(random_valid_action_logits, 1), axis=1) + chose_random = tf.random_uniform( tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps stochastic_actions = tf.where(chose_random, random_actions, @@ -241,6 +255,10 @@ def __init__(self, self.td_error = tf.nn.softmax_cross_entropy_with_logits( labels=m, logits=q_logits_t_selected) self.loss = tf.reduce_mean(self.td_error * importance_weights) + self.stats = { + # TODO: better Q stats for dist dqn + "mean_td_error": tf.reduce_mean(self.td_error), + } else: q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best @@ -252,6 +270,12 @@ def __init__(self, q_t_selected - tf.stop_gradient(q_t_selected_target)) self.loss = tf.reduce_mean( importance_weights * _huber_loss(self.td_error)) + self.stats = { + "mean_q": tf.reduce_mean(q_t_selected), + "min_q": tf.reduce_min(q_t_selected), + "max_q": tf.reduce_max(q_t_selected), + "mean_td_error": tf.reduce_mean(self.td_error), + } class DQNPolicyGraph(TFPolicyGraph): @@ -274,8 +298,8 @@ def __init__(self, observation_space, action_space, config): # Action Q network with tf.variable_scope(Q_SCOPE) as scope: - q_values, q_logits, q_dist = self._build_q_network( - self.cur_observations) + q_values, q_logits, q_dist, _ = self._build_q_network( + self.cur_observations, observation_space) self.q_func_vars = _scope_vars(scope.name) # Action outputs @@ -294,12 +318,17 @@ def __init__(self, observation_space, action_space, config): # q network evaluation with tf.variable_scope(Q_SCOPE, reuse=True): - q_t, q_logits_t, q_dist_t = self._build_q_network(self.obs_t) + prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) + q_t, q_logits_t, q_dist_t, model = self._build_q_network( + self.obs_t, observation_space) + q_batchnorm_update_ops = list( + set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - + prev_update_ops) # target q network evalution with tf.variable_scope(Q_TARGET_SCOPE) as scope: - q_tp1, q_logits_tp1, q_dist_tp1 = self._build_q_network( - self.obs_tp1) + q_tp1, q_logits_tp1, q_dist_tp1, _ = self._build_q_network( + self.obs_tp1, observation_space) self.target_q_func_vars = _scope_vars(scope.name) # q scores for actions which we know were selected in the given state. 
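# Note: illustration only, not part of the patch. "q scores for actions we
# know were selected" are typically gathered with a one-hot mask over the
# action dimension, e.g. (numpy sketch):
import numpy as np

q_t = np.array([[1.0, 5.0, 3.0],
                [2.0, 0.5, 4.0]])           # [batch, num_actions]
act_t = np.array([1, 2])                    # actions taken in the batch
one_hot = np.eye(q_t.shape[1])[act_t]
q_t_selected = (q_t * one_hot).sum(axis=1)  # -> [5.0, 4.0]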
@@ -312,8 +341,8 @@ def __init__(self, observation_space, action_space, config): if config["double_q"]: with tf.variable_scope(Q_SCOPE, reuse=True): q_tp1_using_online_net, q_logits_tp1_using_online_net, \ - q_dist_tp1_using_online_net = self._build_q_network( - self.obs_tp1) + q_dist_tp1_using_online_net, _ = self._build_q_network( + self.obs_tp1, observation_space) q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1) q_tp1_best_one_hot_selection = tf.one_hot( q_tp1_best_using_online_net, self.num_actions) @@ -358,35 +387,18 @@ def __init__(self, observation_space, action_space, config): self.sess, obs_input=self.cur_observations, action_sampler=self.output_actions, - loss=self.loss.loss, - loss_inputs=self.loss_inputs) + loss=model.loss() + self.loss.loss, + loss_inputs=self.loss_inputs, + update_ops=q_batchnorm_update_ops) self.sess.run(tf.global_variables_initializer()) - def _build_q_network(self, obs): - qnet = QNetwork( - ModelCatalog.get_model(obs, 1, self.config["model"]), - self.num_actions, self.config["dueling"], self.config["hiddens"], - self.config["noisy"], self.config["num_atoms"], - self.config["v_min"], self.config["v_max"], self.config["sigma0"]) - return qnet.value, qnet.logits, qnet.dist - - def _build_q_value_policy(self, q_values): - return QValuePolicy(q_values, self.cur_observations, self.num_actions, - self.stochastic, self.eps).action - - def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best, - q_dist_tp1_best): - return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best, - q_dist_tp1_best, self.importance_weights, self.rew_t, - self.done_mask, self.config["gamma"], - self.config["n_step"], self.config["num_atoms"], - self.config["v_min"], self.config["v_max"]) - + @override(TFPolicyGraph) def optimizer(self): return tf.train.AdamOptimizer( learning_rate=self.config["lr"], epsilon=self.config["adam_epsilon"]) + @override(TFPolicyGraph) def gradients(self, optimizer): if self.config["grad_norm_clipping"] is not None: grads_and_vars = _minimize_and_clip( @@ -400,20 +412,36 @@ def gradients(self, optimizer): grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None] return grads_and_vars + @override(TFPolicyGraph) def extra_compute_action_feed_dict(self): return { self.stochastic: True, self.eps: self.cur_epsilon, } + @override(TFPolicyGraph) def extra_compute_grad_fetches(self): return { "td_error": self.loss.td_error, + "stats": self.loss.stats, } - def postprocess_trajectory(self, sample_batch, other_agent_batches=None): + @override(PolicyGraph) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): return _postprocess_dqn(self, sample_batch) + @override(PolicyGraph) + def get_state(self): + return [TFPolicyGraph.get_state(self), self.cur_epsilon] + + @override(PolicyGraph) + def set_state(self, state): + TFPolicyGraph.set_state(self, state[0]) + self.set_epsilon(state[1]) + def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights): td_err = self.sess.run( @@ -434,15 +462,31 @@ def update_target(self): def set_epsilon(self, epsilon): self.cur_epsilon = epsilon - def get_state(self): - return [TFPolicyGraph.get_state(self), self.cur_epsilon] + def _build_q_network(self, obs, space): + qnet = QNetwork( + ModelCatalog.get_model({ + "obs": obs, + "is_training": self._get_is_training_placeholder(), + }, space, self.num_actions, self.config["model"]), + self.num_actions, self.config["dueling"], self.config["hiddens"], + self.config["noisy"], 
self.config["num_atoms"], + self.config["v_min"], self.config["v_max"], self.config["sigma0"]) + return qnet.value, qnet.logits, qnet.dist, qnet.model - def set_state(self, state): - TFPolicyGraph.set_state(self, state[0]) - self.set_epsilon(state[1]) + def _build_q_value_policy(self, q_values): + return QValuePolicy(q_values, self.cur_observations, self.num_actions, + self.stochastic, self.eps).action + + def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best, + q_dist_tp1_best): + return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best, + q_dist_tp1_best, self.importance_weights, self.rew_t, + self.done_mask, self.config["gamma"], + self.config["n_step"], self.config["num_atoms"], + self.config["v_min"], self.config["v_max"]) -def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): +def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): """Rewrites the given trajectory fragments to encode n-step rewards. reward[i] = ( @@ -475,9 +519,9 @@ def _postprocess_dqn(policy_graph, sample_batch): # N-step Q adjustments if policy_graph.config["n_step"] > 1: - adjust_nstep(policy_graph.config["n_step"], - policy_graph.config["gamma"], obs, actions, rewards, - new_obs, dones) + _adjust_nstep(policy_graph.config["n_step"], + policy_graph.config["gamma"], obs, actions, rewards, + new_obs, dones) batch = SampleBatch({ "obs": obs, @@ -500,6 +544,14 @@ def _postprocess_dqn(policy_graph, sample_batch): return batch +def _reduce_mean_ignore_inf(x, axis): + """Same as tf.reduce_mean() but ignores -inf values.""" + mask = tf.not_equal(x, tf.float32.min) + x_zeroed = tf.where(mask, x, tf.zeros_like(x)) + return (tf.reduce_sum(x_zeroed, axis) / tf.reduce_sum( + tf.cast(mask, tf.float32), axis)) + + def _huber_loss(x, delta=1.0): """Reference: https://en.wikipedia.org/wiki/Huber_loss""" return tf.where( diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index 1ce219b7c0ab1..4aa4a86aac889 100644 --- a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -6,25 +6,29 @@ from __future__ import print_function from collections import namedtuple +import logging import numpy as np import time import ray -from ray.rllib.agents import Agent -from ray.tune.trial import Resources +from ray.rllib.agents import Agent, with_common_config from ray.rllib.agents.es import optimizers from ray.rllib.agents.es import policies -from ray.rllib.agents.es import tabular_logger as tlogger from ray.rllib.agents.es import utils -from ray.rllib.utils import merge_dicts +from ray.rllib.utils.annotations import override +from ray.rllib.utils import FilterManager + +logger = logging.getLogger(__name__) Result = namedtuple("Result", [ "noise_indices", "noisy_returns", "sign_noisy_returns", "noisy_lengths", "eval_returns", "eval_lengths" ]) -DEFAULT_CONFIG = { +# yapf: disable +# __sphinx_doc_begin__ +DEFAULT_CONFIG = with_common_config({ "l2_coeff": 0.005, "noise_stdev": 0.02, "episodes_per_batch": 1000, @@ -36,9 +40,9 @@ "observation_filter": "MeanStdFilter", "noise_size": 250000000, "report_length": 10, - "env": None, - "env_config": {}, -} +}) +# __sphinx_doc_end__ +# yapf: enable @ray.remote @@ -76,12 +80,30 @@ def __init__(self, self.env = env_creator(config["env_config"]) from ray.rllib import models - self.preprocessor = models.ModelCatalog.get_preprocessor(self.env) + self.preprocessor = models.ModelCatalog.get_preprocessor( + self.env, config["model"]) self.sess = utils.make_session(single_threaded=True) self.policy = 
policies.GenericPolicy( - self.sess, self.env.action_space, self.preprocessor, - config["observation_filter"], **policy_params) + self.sess, self.env.action_space, self.env.observation_space, + self.preprocessor, config["observation_filter"], config["model"], + **policy_params) + + @property + def filters(self): + return {"default": self.policy.get_filter()} + + def sync_filters(self, new_filters): + for k in self.filters: + self.filters[k].sync(new_filters[k]) + + def get_filters(self, flush_after=False): + return_filters = {} + for k, f in self.filters.items(): + return_filters[k] = f.as_serializable() + if flush_after: + f.clear_buffer() + return return_filters def rollout(self, timestep_limit, add_noise=True): rollout_rewards, rollout_length = policies.rollout( @@ -146,11 +168,7 @@ class ESAgent(Agent): _agent_name = "ES" _default_config = DEFAULT_CONFIG - @classmethod - def default_resource_request(cls, config): - cf = merge_dicts(cls._default_config, config) - return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"]) - + @override(Agent) def _init(self): policy_params = {"action_noise_std": 0.01} @@ -160,18 +178,19 @@ def _init(self): self.sess = utils.make_session(single_threaded=False) self.policy = policies.GenericPolicy( - self.sess, env.action_space, preprocessor, - self.config["observation_filter"], **policy_params) + self.sess, env.action_space, env.observation_space, preprocessor, + self.config["observation_filter"], self.config["model"], + **policy_params) self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"]) self.report_length = self.config["report_length"] # Create the shared noise table. - print("Creating shared noise table.") + logger.info("Creating shared noise table.") noise_id = create_shared_noise.remote(self.config["noise_size"]) self.noise = SharedNoiseTable(ray.get(noise_id)) # Create the actors. - print("Creating actors.") + logger.info("Creating actors.") self.workers = [ Worker.remote(self.config, policy_params, self.env_creator, noise_id) for _ in range(self.config["num_workers"]) @@ -181,26 +200,7 @@ def _init(self): self.reward_list = [] self.tstart = time.time() - def _collect_results(self, theta_id, min_episodes, min_timesteps): - num_episodes, num_timesteps = 0, 0 - results = [] - while num_episodes < min_episodes or num_timesteps < min_timesteps: - print("Collected {} episodes {} timesteps so far this iter".format( - num_episodes, num_timesteps)) - rollout_ids = [ - worker.do_rollouts.remote(theta_id) for worker in self.workers - ] - # Get the results of the rollouts. - for result in ray.get(rollout_ids): - results.append(result) - # Update the number of episodes and the number of timesteps - # keeping in mind that result.noisy_lengths is a list of lists, - # where the inner lists have length 2. 
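The Worker class above now owns its observation filter and exposes it through the filters property, sync_filters(), and get_filters(flush_after=True): the driver periodically pulls the buffered statistics from each worker, folds them into a central filter, and broadcasts the merged filter back (the FilterManager.synchronize call added to _train further down). A self-contained sketch of that round trip, using a toy running-mean filter rather than Ray's MeanStdFilter/FilterManager; ToyFilter and the driver/worker objects here are illustrative only:

import numpy as np


class ToyFilter(object):
    """Running-mean filter; buf_* holds statistics gathered since the last flush."""

    def __init__(self):
        self.count, self.total = 0, 0.0          # merged statistics
        self.buf_count, self.buf_total = 0, 0.0  # un-flushed local deltas

    def observe(self, x):
        self.count += 1
        self.total += x
        self.buf_count += 1
        self.buf_total += x

    def as_serializable(self):
        return {"count": self.buf_count, "total": self.buf_total}

    def clear_buffer(self):
        self.buf_count, self.buf_total = 0, 0.0

    def apply_delta(self, delta):
        self.count += delta["count"]
        self.total += delta["total"]

    def sync(self, other):
        self.count, self.total = other.count, other.total


central, workers = ToyFilter(), [ToyFilter() for _ in range(2)]
for w in workers:                                # each worker filters its own rollouts
    for x in np.random.randn(5):
        w.observe(x)
deltas = [w.as_serializable() for w in workers]  # ~ get_filters(flush_after=True)
for w in workers:
    w.clear_buffer()
for delta in deltas:
    central.apply_delta(delta)                   # merge into the driver's filter
for w in workers:
    w.sync(central)                              # ~ sync_filters(new_filters)
print(central.count, round(central.total, 3))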
- num_episodes += sum(len(pair) for pair in result.noisy_lengths) - num_timesteps += sum( - sum(pair) for pair in result.noisy_lengths) - return results, num_episodes, num_timesteps - + @override(Agent) def _train(self): config = self.config @@ -266,20 +266,10 @@ def _train(self): if len(all_eval_returns) > 0: self.reward_list.append(np.mean(eval_returns)) - tlogger.record_tabular("EvalEpRewStd", eval_returns.std()) - tlogger.record_tabular("EvalEpLenMean", eval_lengths.mean()) - - tlogger.record_tabular("EpRewMean", noisy_returns.mean()) - tlogger.record_tabular("EpRewStd", noisy_returns.std()) - tlogger.record_tabular("EpLenMean", noisy_lengths.mean()) - - tlogger.record_tabular("Norm", float(np.square(theta).sum())) - tlogger.record_tabular("GradNorm", float(np.square(g).sum())) - tlogger.record_tabular("UpdateRatio", float(update_ratio)) - - tlogger.record_tabular("EpisodesThisIter", noisy_lengths.size) - tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far) - tlogger.dump_tabular() + # Now sync the filters + FilterManager.synchronize({ + "default": self.policy.get_filter() + }, self.workers) info = { "weights_norm": np.square(theta).sum(), @@ -298,20 +288,49 @@ def _train(self): return result + @override(Agent) + def compute_action(self, observation): + return self.policy.compute(observation, update=False)[0] + + @override(Agent) def _stop(self): # workaround for https://github.com/ray-project/ray/issues/1516 for w in self.workers: w.__ray_terminate__.remote() + def _collect_results(self, theta_id, min_episodes, min_timesteps): + num_episodes, num_timesteps = 0, 0 + results = [] + while num_episodes < min_episodes or num_timesteps < min_timesteps: + logger.info( + "Collected {} episodes {} timesteps so far this iter".format( + num_episodes, num_timesteps)) + rollout_ids = [ + worker.do_rollouts.remote(theta_id) for worker in self.workers + ] + # Get the results of the rollouts. + for result in ray.get(rollout_ids): + results.append(result) + # Update the number of episodes and the number of timesteps + # keeping in mind that result.noisy_lengths is a list of lists, + # where the inner lists have length 2. 
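Each entry of result.noisy_lengths is a length-2 list, one rollout length per perturbation sign of a sampled noise index, which is why the episode count adds len(pair) and the timestep count adds sum(pair). A tiny worked example with made-up lengths:

# Hypothetical lengths: [len with +noise, len with -noise] per sampled noise index.
noisy_lengths = [[120, 95], [88, 102]]

num_episodes = sum(len(pair) for pair in noisy_lengths)   # 2 + 2 = 4 episodes
num_timesteps = sum(sum(pair) for pair in noisy_lengths)  # 215 + 190 = 405 steps
print(num_episodes, num_timesteps)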
+ num_episodes += sum(len(pair) for pair in result.noisy_lengths) + num_timesteps += sum( + sum(pair) for pair in result.noisy_lengths) + + return results, num_episodes, num_timesteps + def __getstate__(self): return { "weights": self.policy.get_weights(), + "filter": self.policy.get_filter(), "episodes_so_far": self.episodes_so_far, } def __setstate__(self, state): - self.policy.set_weights(state["weights"]) self.episodes_so_far = state["episodes_so_far"] - - def compute_action(self, observation): - return self.policy.compute(observation, update=False)[0] + self.policy.set_weights(state["weights"]) + self.policy.set_filter(state["filter"]) + FilterManager.synchronize({ + "default": self.policy.get_filter() + }, self.workers) diff --git a/python/ray/rllib/agents/es/policies.py b/python/ray/rllib/agents/es/policies.py index d62fee43c4c57..cf2da630e0866 100644 --- a/python/ray/rllib/agents/es/policies.py +++ b/python/ray/rllib/agents/es/policies.py @@ -10,6 +10,7 @@ import tensorflow as tf import ray +from ray.rllib.evaluation.sampler import _unbatch_tuple_actions from ray.rllib.models import ModelCatalog from ray.rllib.utils.filter import get_filter @@ -38,8 +39,8 @@ def rollout(policy, env, timestep_limit=None, add_noise=False): class GenericPolicy(object): - def __init__(self, sess, action_space, preprocessor, observation_filter, - action_noise_std): + def __init__(self, sess, action_space, obs_space, preprocessor, + observation_filter, model_options, action_noise_std): self.sess = sess self.action_space = action_space self.action_noise_std = action_noise_std @@ -51,8 +52,10 @@ def __init__(self, sess, action_space, preprocessor, observation_filter, # Policy network. dist_class, dist_dim = ModelCatalog.get_action_dist( - self.action_space, dist_type="deterministic") - model = ModelCatalog.get_model(self.inputs, dist_dim) + self.action_space, model_options, dist_type="deterministic") + model = ModelCatalog.get_model({ + "obs": self.inputs + }, obs_space, dist_dim, model_options) dist = dist_class(model.outputs) self.sampler = dist.sample() @@ -69,6 +72,7 @@ def compute(self, observation, add_noise=False, update=True): observation = self.observation_filter(observation[None], update=update) action = self.sess.run( self.sampler, feed_dict={self.inputs: observation}) + action = _unbatch_tuple_actions(action) if add_noise and isinstance(self.action_space, gym.spaces.Box): action += np.random.randn(*action.shape) * self.action_noise_std return action @@ -78,3 +82,9 @@ def set_weights(self, x): def get_weights(self): return self.variables.get_flat() + + def get_filter(self): + return self.observation_filter + + def set_filter(self, observation_filter): + self.observation_filter = observation_filter diff --git a/python/ray/rllib/agents/es/tabular_logger.py b/python/ray/rllib/agents/es/tabular_logger.py deleted file mode 100644 index 1463e59e07046..0000000000000 --- a/python/ray/rllib/agents/es/tabular_logger.py +++ /dev/null @@ -1,229 +0,0 @@ -# Code in this file is copied and adapted from -# https://github.com/openai/evolution-strategies-starter. 
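The hand-rolled tabular logger deleted below is superseded by the standard logging module: the agents touched in this diff now create module-level loggers (logger = logging.getLogger(__name__)) and report metrics through the returned result dict rather than tabular dumps. To surface those logger.info messages when running a script directly, the stdlib logger can be configured in the usual way, for example:

import logging

# Show INFO-level messages from module loggers such as ray.rllib.agents.es.es.
logging.basicConfig(level=logging.INFO)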
- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import OrderedDict -import os -import sys -import time - -import tensorflow as tf -from tensorflow.core.util import event_pb2 -from tensorflow.python import pywrap_tensorflow -from tensorflow.python.util import compat - -DEBUG = 10 -INFO = 20 -WARN = 30 -ERROR = 40 - -DISABLED = 50 - - -class TbWriter(object): - """Based on SummaryWriter, but changed to allow for a different prefix.""" - - def __init__(self, dir, prefix): - self.dir = dir - # Start at 1, because EvWriter automatically generates an object with - # step = 0. - self.step = 1 - self.evwriter = pywrap_tensorflow.EventsWriter( - compat.as_bytes(os.path.join(dir, prefix))) - - def write_values(self, key2val): - summary = tf.Summary(value=[ - tf.Summary.Value(tag=k, simple_value=float(v)) - for (k, v) in key2val.items() - ]) - event = event_pb2.Event(wall_time=time.time(), summary=summary) - event.step = self.step - self.evwriter.WriteEvent(event) - self.evwriter.Flush() - self.step += 1 - - def close(self): - self.evwriter.Close() - - -# API - - -def start(dir): - if _Logger.CURRENT is not _Logger.DEFAULT: - sys.stderr.write("WARNING: You asked to start logging (dir=%s), but " - "you never stopped the previous logger (dir=%s)." - "\n" % (dir, _Logger.CURRENT.dir)) - _Logger.CURRENT = _Logger(dir=dir) - - -def stop(): - if _Logger.CURRENT is _Logger.DEFAULT: - sys.stderr.write("WARNING: You asked to stop logging, but you never " - "started any previous logger." - "\n" % (dir, _Logger.CURRENT.dir)) - return - _Logger.CURRENT.close() - _Logger.CURRENT = _Logger.DEFAULT - - -def record_tabular(key, val): - """Log a value of some diagnostic. - - Call this once for each diagnostic quantity, each iteration. - """ - _Logger.CURRENT.record_tabular(key, val) - - -def dump_tabular(): - """Write all of the diagnostics from the current iteration.""" - _Logger.CURRENT.dump_tabular() - - -def log(*args, **kwargs): - """Write the sequence of args, with no separators. - - This is written to the console and output files (if you've configured an - output file). - """ - level = kwargs['level'] if 'level' in kwargs else INFO - _Logger.CURRENT.log(*args, level=level) - - -def debug(*args): - log(*args, level=DEBUG) - - -def info(*args): - log(*args, level=INFO) - - -def warn(*args): - log(*args, level=WARN) - - -def error(*args): - log(*args, level=ERROR) - - -def set_level(level): - """ - Set logging threshold on current logger. - """ - _Logger.CURRENT.set_level(level) - - -def get_dir(): - """ - Get directory that log files are being written to. - will be None if there is no output directory (i.e., if you didn't call - start) - """ - return _Logger.CURRENT.get_dir() - - -def get_expt_dir(): - sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir()\n") - return get_dir() - - -# Backend - - -class _Logger(object): - # A logger with no output files. (See right below class definition) so that - # you can still log to the terminal without setting up any output files. - DEFAULT = None - # Current logger being used by the free functions above. - CURRENT = None - - def __init__(self, dir=None): - self.name2val = OrderedDict() # Values this iteration. 
- self.level = INFO - self.dir = dir - self.text_outputs = [sys.stdout] - if dir is not None: - os.makedirs(dir, exist_ok=True) - self.text_outputs.append(open(os.path.join(dir, "log.txt"), "w")) - self.tbwriter = TbWriter(dir=dir, prefix="events") - else: - self.tbwriter = None - - # Logging API, forwarded - - def record_tabular(self, key, val): - self.name2val[key] = val - - def dump_tabular(self): - # Create strings for printing. - key2str = OrderedDict() - for (key, val) in self.name2val.items(): - if hasattr(val, "__float__"): - valstr = "%-8.3g" % val - else: - valstr = val - key2str[self._truncate(key)] = self._truncate(valstr) - keywidth = max(map(len, key2str.keys())) - valwidth = max(map(len, key2str.values())) - # Write to all text outputs - self._write_text("-" * (keywidth + valwidth + 7), "\n") - for (key, val) in key2str.items(): - self._write_text("| ", key, " " * (keywidth - len(key)), " | ", - val, " " * (valwidth - len(val)), " |\n") - self._write_text("-" * (keywidth + valwidth + 7), "\n") - for f in self.text_outputs: - try: - f.flush() - except OSError: - sys.stderr.write('Warning! OSError when flushing.\n') - # Write to tensorboard - if self.tbwriter is not None: - self.tbwriter.write_values(self.name2val) - self.name2val.clear() - - def log(self, *args, **kwargs): - level = kwargs['level'] if 'level' in kwargs else INFO - if self.level <= level: - self._do_log(*args) - - # Configuration - - def set_level(self, level): - self.level = level - - def get_dir(self): - return self.dir - - def close(self): - for f in self.text_outputs[1:]: - f.close() - if self.tbwriter: - self.tbwriter.close() - - # Misc - - def _do_log(self, *args): - self._write_text(*args + ('\n', )) - for f in self.text_outputs: - try: - f.flush() - except OSError: - print('Warning! OSError when flushing.') - - def _write_text(self, *strings): - for f in self.text_outputs: - for string in strings: - f.write(string) - - def _truncate(self, s): - if len(s) > 33: - return s[:30] + "..." - else: - return s - - -_Logger.DEFAULT = _Logger() -_Logger.CURRENT = _Logger.DEFAULT diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index cfa55bd735c88..aa789387f9300 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -8,13 +8,24 @@ from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.optimizers import AsyncSamplesOptimizer -from ray.tune.trial import Resources +from ray.rllib.utils.annotations import override OPTIMIZER_SHARED_CONFIGS = [ + "lr", + "num_envs_per_worker", + "num_gpus", "sample_batch_size", "train_batch_size", + "replay_buffer_num_slots", + "replay_proportion", + "num_parallel_data_loaders", + "grad_clip", + "max_sample_requests_in_flight_per_worker", + "broadcast_interval", ] +# yapf: disable +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # V-trace params (see vtrace.py). "vtrace": True, @@ -25,10 +36,22 @@ "sample_batch_size": 50, "train_batch_size": 500, "min_iter_time_s": 10, - "gpu": True, "num_workers": 2, - "num_cpus_per_worker": 1, - "num_gpus_per_worker": 0, + # number of GPUs the learner should use. + "num_gpus": 1, + # set >1 to load data into GPUs in parallel. Increases GPU memory usage + # proportionally with the number of loaders. + "num_parallel_data_loaders": 1, + # level of queuing for sampling. 
+ "max_sample_requests_in_flight_per_worker": 2, + # max number of workers to broadcast one set of weights to + "broadcast_interval": 1, + # set >0 to enable experience replay. Saved samples will be replayed with + # a p:1 proportion to new data samples. + "replay_proportion": 0.0, + # number of sample batches to store for replay. The number of transitions + # saved total will be (replay_buffer_num_slots * sample_batch_size). + "replay_buffer_num_slots": 100, # Learning params. "grad_clip": 40.0, @@ -43,14 +66,9 @@ # balancing the three losses "vf_loss_coeff": 0.5, "entropy_coeff": -0.01, - - # Model and preprocessor options. - "model": { - "use_lstm": False, - "max_seq_len": 20, - "dim": 84, - }, }) +# __sphinx_doc_end__ +# yapf: enable class ImpalaAgent(Agent): @@ -60,15 +78,7 @@ class ImpalaAgent(Agent): _default_config = DEFAULT_CONFIG _policy_graph = VTracePolicyGraph - @classmethod - def default_resource_request(cls, config): - cf = dict(cls._default_config, **config) - return Resources( - cpu=1, - gpu=cf["gpu"] and cf["gpu_fraction"] or 0, - extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], - extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) - + @override(Agent) def _init(self): for k in OPTIMIZER_SHARED_CONFIGS: if k not in self.config["optimizer"]: @@ -80,19 +90,20 @@ def _init(self): self.local_evaluator = self.make_local_evaluator( self.env_creator, policy_cls) self.remote_evaluators = self.make_remote_evaluators( - self.env_creator, policy_cls, self.config["num_workers"], - {"num_cpus": 1}) + self.env_creator, policy_cls, self.config["num_workers"]) self.optimizer = AsyncSamplesOptimizer(self.local_evaluator, self.remote_evaluators, self.config["optimizer"]) + @override(Agent) def _train(self): prev_steps = self.optimizer.num_steps_sampled start = time.time() self.optimizer.step() while time.time() - start < self.config["min_iter_time_s"]: self.optimizer.step() - result = self.optimizer.collect_metrics() + result = self.optimizer.collect_metrics( + self.config["collect_metrics_timeout"]) result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps) return result diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index 23f88e51f51b7..5eed0a6e79e5f 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -11,12 +11,14 @@ import ray from ray.rllib.agents.impala import vtrace +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ LearningRateSchedule from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.misc import linear, normc_initializer +from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.models.action_dist import Categorical class VTraceLoss(object): @@ -31,6 +33,7 @@ def __init__(self, rewards, values, bootstrap_value, + valid_mask, vf_loss_coeff=0.5, entropy_coeff=-0.01, clip_rho_threshold=1.0, @@ -52,6 +55,7 @@ def __init__(self, rewards: A float32 tensor of shape [T, B]. values: A float32 tensor of shape [T, B]. bootstrap_value: A float32 tensor of shape [B]. + valid_mask: A bool tensor of valid RNN input elements (#2992). """ # Compute vtrace on the CPU for better perf. 
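valid_mask exists because LSTM rollouts are padded to a common sequence length before being reshaped to [T, B]; the padded timesteps must not contribute to the loss, and the following hunk excludes them with tf.boolean_mask. A minimal TF 1.x sketch of that masking idea with toy values (not the actual graph built here):

import tensorflow as tf

# Two sequences padded to length 4; only the first 3 and 1 entries are real.
seq_lens = tf.constant([3, 1])
values = tf.constant([[1.0, 2.0, 3.0, -7.0],
                      [5.0, -7.0, -7.0, -7.0]])  # -7.0 marks padding garbage

mask = tf.sequence_mask(seq_lens, maxlen=4)       # bool tensor of shape [B, T]
valid = tf.boolean_mask(values, mask)             # keeps only the 4 real entries
mean_valid = tf.reduce_mean(valid)                # 2.75, padding ignored

with tf.Session() as sess:
    print(sess.run([valid, mean_valid]))          # [1. 2. 3. 5.], 2.75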
@@ -70,14 +74,16 @@ def __init__(self, # The policy gradients loss self.pi_loss = -tf.reduce_sum( - actions_logp * self.vtrace_returns.pg_advantages) + tf.boolean_mask(actions_logp * self.vtrace_returns.pg_advantages, + valid_mask)) # The baseline loss - delta = values - self.vtrace_returns.vs + delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask) self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta)) # The entropy loss - self.entropy = tf.reduce_sum(actions_entropy) + self.entropy = tf.reduce_sum( + tf.boolean_mask(actions_entropy, valid_mask)) # The summed weighted loss self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff + @@ -85,40 +91,62 @@ def __init__(self, class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph): - def __init__(self, observation_space, action_space, config): + def __init__(self, + observation_space, + action_space, + config, + existing_inputs=None): config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config) assert config["batch_mode"] == "truncate_episodes", \ "Must use `truncate_episodes` batch mode with V-trace." self.config = config self.sess = tf.get_default_session() + # Create input placeholders + if existing_inputs: + actions, dones, behaviour_logits, rewards, observations, \ + prev_actions, prev_rewards = existing_inputs[:7] + existing_state_in = existing_inputs[7:-1] + existing_seq_lens = existing_inputs[-1] + else: + if isinstance(action_space, gym.spaces.Discrete): + ac_size = action_space.n + actions = tf.placeholder(tf.int64, [None], name="ac") + else: + raise UnsupportedSpaceException( + "Action space {} is not supported for IMPALA.".format( + action_space)) + dones = tf.placeholder(tf.bool, [None], name="dones") + rewards = tf.placeholder(tf.float32, [None], name="rewards") + behaviour_logits = tf.placeholder( + tf.float32, [None, ac_size], name="behaviour_logits") + observations = tf.placeholder( + tf.float32, [None] + list(observation_space.shape)) + existing_state_in = None + existing_seq_lens = None + # Setup the policy - self.observations = tf.placeholder( - tf.float32, [None] + list(observation_space.shape)) dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) - self.model = ModelCatalog.get_model(self.observations, logit_dim, - self.config["model"]) + prev_actions = ModelCatalog.get_action_placeholder(action_space) + prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") + self.model = ModelCatalog.get_model( + { + "obs": observations, + "prev_actions": prev_actions, + "prev_rewards": prev_rewards, + "is_training": self._get_is_training_placeholder(), + }, + observation_space, + logit_dim, + self.config["model"], + state_in=existing_state_in, + seq_lens=existing_seq_lens) action_dist = dist_class(self.model.outputs) - values = tf.reshape( - linear(self.model.last_layer, 1, "value", normc_initializer(1.0)), - [-1]) + values = self.model.value_function() self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) - # Setup the policy loss - if isinstance(action_space, gym.spaces.Discrete): - ac_size = action_space.n - actions = tf.placeholder(tf.int64, [None], name="ac") - else: - raise UnsupportedSpaceException( - "Action space {} is not supported for IMPALA.".format( - action_space)) - dones = tf.placeholder(tf.bool, [None], name="dones") - rewards = tf.placeholder(tf.float32, [None], name="rewards") - behaviour_logits = tf.placeholder( - tf.float32, [None, ac_size], name="behaviour_logits") - def 
to_batches(tensor): if self.config["model"]["use_lstm"]: B = tf.shape(self.model.seq_lens)[0] @@ -126,8 +154,7 @@ def to_batches(tensor): else: # Important: chop the tensor into batches at known episode cut # boundaries. TODO(ekl) this is kind of a hack - T = (self.config["sample_batch_size"] // - self.config["num_envs_per_worker"]) + T = self.config["sample_batch_size"] B = tf.shape(tensor)[0] // T rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0)) @@ -136,6 +163,13 @@ def to_batches(tensor): rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) + if self.model.state_in: + max_seq_len = tf.reduce_max(self.model.seq_lens) - 1 + mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) + mask = tf.reshape(mask, [-1]) + else: + mask = tf.ones_like(rewards) + # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. self.loss = VTraceLoss( actions=to_batches(actions)[:-1], @@ -148,18 +182,29 @@ def to_batches(tensor): rewards=to_batches(rewards)[:-1], values=to_batches(values)[:-1], bootstrap_value=to_batches(values)[-1], + valid_mask=to_batches(mask)[:-1], vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.config["entropy_coeff"], clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"]) + # KL divergence between worker and learner logits for debugging + model_dist = Categorical(self.model.outputs) + behaviour_dist = Categorical(behaviour_logits) + self.KLs = model_dist.kl(behaviour_dist) + self.mean_KL = tf.reduce_mean(self.KLs) + self.max_KL = tf.reduce_max(self.KLs) + self.median_KL = tf.contrib.distributions.percentile(self.KLs, 50.0) + # Initialize TFPolicyGraph loss_in = [ ("actions", actions), ("dones", dones), ("behaviour_logits", behaviour_logits), ("rewards", rewards), - ("obs", self.observations), + ("obs", observations), + ("prev_actions", prev_actions), + ("prev_rewards", prev_rewards), ] LearningRateSchedule.__init__(self, self.config["lr"], self.config["lr_schedule"]) @@ -168,14 +213,17 @@ def to_batches(tensor): observation_space, action_space, self.sess, - obs_input=self.observations, + obs_input=observations, action_sampler=action_dist.sample(), - loss=self.loss.total_loss, + loss=self.model.loss() + self.loss.total_loss, loss_inputs=loss_in, state_inputs=self.model.state_in, state_outputs=self.model.state_out, + prev_action_input=prev_actions, + prev_reward_input=prev_rewards, seq_lens=self.model.seq_lens, - max_seq_len=self.config["model"]["max_seq_len"]) + max_seq_len=self.config["model"]["max_seq_len"], + batch_divisibility_req=self.config["sample_batch_size"]) self.sess.run(tf.global_variables_initializer()) @@ -190,9 +238,21 @@ def to_batches(tensor): "vf_explained_var": explained_variance( tf.reshape(self.loss.vtrace_returns.vs, [-1]), tf.reshape(to_batches(values)[:-1], [-1])), + "mean_KL": self.mean_KL, + "max_KL": self.max_KL, + "median_KL": self.median_KL, }, } + @override(TFPolicyGraph) + def copy(self, existing_inputs): + return VTracePolicyGraph( + self.observation_space, + self.action_space, + self.config, + existing_inputs=existing_inputs) + + @override(TFPolicyGraph) def optimizer(self): if self.config["opt_type"] == "adam": return tf.train.AdamOptimizer(self.cur_lr) @@ -201,21 +261,29 @@ def optimizer(self): self.config["momentum"], self.config["epsilon"]) + @override(TFPolicyGraph) def gradients(self, optimizer): grads = tf.gradients(self.loss.total_loss, self.var_list) self.grads, _ = tf.clip_by_global_norm(grads, 
self.config["grad_clip"]) clipped_grads = list(zip(self.grads, self.var_list)) return clipped_grads + @override(TFPolicyGraph) def extra_compute_action_fetches(self): return {"behaviour_logits": self.model.outputs} + @override(TFPolicyGraph) def extra_compute_grad_fetches(self): return self.stats_fetches - def postprocess_trajectory(self, sample_batch, other_agent_batches=None): + @override(PolicyGraph) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): del sample_batch.data["new_obs"] # not used, so save some bandwidth return sample_batch + @override(PolicyGraph) def get_initial_state(self): return self.model.state_init diff --git a/python/ray/rllib/agents/mock.py b/python/ray/rllib/agents/mock.py index 526ec146a8e2f..f4bf909918095 100644 --- a/python/ray/rllib/agents/mock.py +++ b/python/ray/rllib/agents/mock.py @@ -6,18 +6,19 @@ import pickle import numpy as np -from ray.rllib.agents.agent import Agent +from ray.rllib.agents.agent import Agent, with_common_config class _MockAgent(Agent): """Mock agent for use in tests""" _agent_name = "MockAgent" - _default_config = { + _default_config = with_common_config({ "mock_error": False, "persistent_error": False, - "test_variable": 1 - } + "test_variable": 1, + "num_workers": 0, + }) def _init(self): self.info = None @@ -59,13 +60,14 @@ class _SigmoidFakeData(_MockAgent): This can be helpful for evaluating early stopping algorithms.""" _agent_name = "SigmoidFakeData" - _default_config = { + _default_config = with_common_config({ "width": 100, "height": 100, "offset": 0, "iter_time": 10, "iter_timesteps": 1, - } + "num_workers": 0, + }) def _train(self): i = max(0, self.iteration - self.config["offset"]) @@ -82,13 +84,14 @@ def _train(self): class _ParameterTuningAgent(_MockAgent): _agent_name = "ParameterTuningAgent" - _default_config = { + _default_config = with_common_config({ "reward_amt": 10, "dummy_param": 10, "dummy_param2": 15, "iter_time": 10, - "iter_timesteps": 1 - } + "iter_timesteps": 1, + "num_workers": 0, + }) def _train(self): return dict( diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py index e1766e7744f23..69c6761863c6a 100644 --- a/python/ray/rllib/agents/pg/pg.py +++ b/python/ray/rllib/agents/pg/pg.py @@ -5,22 +5,18 @@ from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph from ray.rllib.optimizers import SyncSamplesOptimizer -from ray.rllib.utils import merge_dicts -from ray.tune.trial import Resources +from ray.rllib.utils.annotations import override +# yapf: disable +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # No remote workers by default "num_workers": 0, # Learning rate "lr": 0.0004, - # Override model config - "model": { - # Use LSTM model. - "use_lstm": False, - # Max seq length for LSTM training. 
- "max_seq_len": 20, - }, }) +# __sphinx_doc_end__ +# yapf: enable class PGAgent(Agent): @@ -34,25 +30,22 @@ class PGAgent(Agent): _default_config = DEFAULT_CONFIG _policy_graph = PGPolicyGraph - @classmethod - def default_resource_request(cls, config): - cf = merge_dicts(cls._default_config, config) - return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"]) - + @override(Agent) def _init(self): self.local_evaluator = self.make_local_evaluator( self.env_creator, self._policy_graph) self.remote_evaluators = self.make_remote_evaluators( - self.env_creator, self._policy_graph, self.config["num_workers"], - {}) + self.env_creator, self._policy_graph, self.config["num_workers"]) self.optimizer = SyncSamplesOptimizer(self.local_evaluator, self.remote_evaluators, self.config["optimizer"]) + @override(Agent) def _train(self): prev_steps = self.optimizer.num_steps_sampled self.optimizer.step() - result = self.optimizer.collect_metrics() + result = self.optimizer.collect_metrics( + self.config["collect_metrics_timeout"]) result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps) return result diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy_graph.py index bb831c47d4ee4..59e9a9effc12b 100644 --- a/python/ray/rllib/agents/pg/pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/pg_policy_graph.py @@ -7,25 +7,39 @@ import ray from ray.rllib.models.catalog import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.utils.annotations import override class PGLoss(object): + """Simple policy gradient loss.""" + def __init__(self, action_dist, actions, advantages): self.loss = -tf.reduce_mean(action_dist.logp(actions) * advantages) class PGPolicyGraph(TFPolicyGraph): + """Simple policy gradient example of defining a policy graph.""" + def __init__(self, obs_space, action_space, config): config = dict(ray.rllib.agents.pg.pg.DEFAULT_CONFIG, **config) self.config = config - # Setup policy + # Setup placeholders obs = tf.placeholder(tf.float32, shape=[None] + list(obs_space.shape)) dist_class, self.logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) - self.model = ModelCatalog.get_model( - obs, self.logit_dim, options=self.config["model"]) + prev_actions = ModelCatalog.get_action_placeholder(action_space) + prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") + + # Create the model network and action outputs + self.model = ModelCatalog.get_model({ + "obs": obs, + "prev_actions": prev_actions, + "prev_rewards": prev_rewards, + "is_training": self._get_is_training_placeholder(), + }, obs_space, self.logit_dim, self.config["model"]) action_dist = dist_class(self.model.outputs) # logit for each action # Setup policy loss @@ -33,14 +47,19 @@ def __init__(self, obs_space, action_space, config): advantages = tf.placeholder(tf.float32, [None], name="adv") loss = PGLoss(action_dist, actions, advantages).loss - # Initialize TFPolicyGraph - sess = tf.get_default_session() + # Mapping from sample batch keys to placeholders. These keys will be + # read from postprocessed sample batches and fed into the specified + # placeholders during loss computation. 
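Concretely, each (column name, placeholder) pair declared in loss_in below tells TFPolicyGraph which array of the postprocessed batch to feed into which placeholder when the loss is evaluated. A framework-free sketch of that pairing; the arrays, names, and string "placeholders" are stand-ins, not the real objects:

import numpy as np

# Stand-ins: placeholder objects replaced by plain strings for illustration.
loss_inputs = [("obs", "obs_ph"), ("actions", "act_ph"), ("advantages", "adv_ph")]

# A postprocessed sample batch: equal-length columns keyed by name.
batch = {
    "obs": np.random.randn(4, 3),
    "actions": np.array([0, 1, 1, 0]),
    "advantages": np.array([0.5, -0.2, 1.1, 0.0]),
}

# The feed dict pairs each declared column with its placeholder.
feed_dict = {ph: batch[name] for name, ph in loss_inputs}
print(sorted(feed_dict.keys()))  # ['act_ph', 'adv_ph', 'obs_ph']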
loss_in = [ ("obs", obs), ("actions", actions), - ("advantages", advantages), + ("prev_actions", prev_actions), + ("prev_rewards", prev_rewards), + ("advantages", advantages), # added during postprocessing ] + # Initialize TFPolicyGraph + sess = tf.get_default_session() TFPolicyGraph.__init__( self, obs_space, @@ -48,17 +67,25 @@ def __init__(self, obs_space, action_space, config): sess, obs_input=obs, action_sampler=action_dist.sample(), - loss=loss, + loss=self.model.loss() + loss, loss_inputs=loss_in, state_inputs=self.model.state_in, state_outputs=self.model.state_out, + prev_action_input=prev_actions, + prev_reward_input=prev_rewards, seq_lens=self.model.seq_lens, max_seq_len=config["model"]["max_seq_len"]) sess.run(tf.global_variables_initializer()) - def postprocess_trajectory(self, sample_batch, other_agent_batches=None): + @override(PolicyGraph) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + # This adds the "advantages" column to the sample batch return compute_advantages( sample_batch, 0.0, self.config["gamma"], use_gae=False) + @override(PolicyGraph) def get_initial_state(self): return self.model.state_init diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index f452f789397ea..0c10b279ab221 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -2,12 +2,17 @@ from __future__ import division from __future__ import print_function +import logging + from ray.rllib.agents import Agent, with_common_config from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph -from ray.rllib.utils import merge_dicts from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer -from ray.tune.trial import Resources +from ray.rllib.utils.annotations import override + +logger = logging.getLogger(__name__) +# yapf: disable +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # If true, use the Generalized Advantage Estimator (GAE) # with a value function, see https://arxiv.org/pdf/1506.02438.pdf. @@ -20,7 +25,7 @@ "sample_batch_size": 200, # Number of timesteps collected for each SGD round "train_batch_size": 4000, - # Total SGD batch size across all devices for SGD (multi-gpu only) + # Total SGD batch size across all devices for SGD "sgd_minibatch_size": 128, # Number of SGD iterations in each outer loop "num_sgd_iter": 30, @@ -41,26 +46,16 @@ "vf_clip_param": 10.0, # Target value for KL divergence "kl_target": 0.01, - # Number of GPUs to use for SGD - "num_gpus": 0, - # Whether to allocate GPUs for workers (if > 0). - "num_gpus_per_worker": 0, - # Whether to allocate CPUs for workers (if > 0). - "num_cpus_per_worker": 1, # Whether to rollout "complete_episodes" or "truncate_episodes" - "batch_mode": "complete_episodes", + "batch_mode": "truncate_episodes", # Which observation filter to apply to the observation "observation_filter": "MeanStdFilter", - # Use the sync samples optimizer instead of the multi-gpu one + # Uses the sync samples optimizer instead of the multi-gpu one. This does + # not support minibatches. "simple_optimizer": False, - # Override model config - "model": { - # Whether to use LSTM model - "use_lstm": False, - # Max seq length for LSTM training. 
- "max_seq_len": 20, - }, }) +# __sphinx_doc_end__ +# yapf: enable class PPOAgent(Agent): @@ -70,39 +65,18 @@ class PPOAgent(Agent): _default_config = DEFAULT_CONFIG _policy_graph = PPOPolicyGraph - @classmethod - def default_resource_request(cls, config): - cf = merge_dicts(cls._default_config, config) - return Resources( - cpu=1, - gpu=cf["num_gpus"], - extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], - extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) - + @override(Agent) def _init(self): - waste_ratio = ( - self.config["sample_batch_size"] * self.config["num_workers"] / - self.config["train_batch_size"]) - if waste_ratio > 1: - msg = ("sample_batch_size * num_workers >> train_batch_size. " - "This means that many steps will be discarded. Consider " - "reducing sample_batch_size, or increase train_batch_size.") - if waste_ratio > 1.5: - raise ValueError(msg) - else: - print("Warning: " + msg) + self._validate_config() self.local_evaluator = self.make_local_evaluator( self.env_creator, self._policy_graph) self.remote_evaluators = self.make_remote_evaluators( - self.env_creator, self._policy_graph, self.config["num_workers"], { - "num_cpus": self.config["num_cpus_per_worker"], - "num_gpus": self.config["num_gpus_per_worker"] - }) + self.env_creator, self._policy_graph, self.config["num_workers"]) if self.config["simple_optimizer"]: self.optimizer = SyncSamplesOptimizer( self.local_evaluator, self.remote_evaluators, { "num_sgd_iter": self.config["num_sgd_iter"], - "train_batch_size": self.config["train_batch_size"] + "train_batch_size": self.config["train_batch_size"], }) else: self.optimizer = LocalMultiGPUOptimizer( @@ -114,6 +88,7 @@ def _init(self): "standardize_fields": ["advantages"], }) + @override(Agent) def _train(self): prev_steps = self.optimizer.num_steps_sampled fetches = self.optimizer.step() @@ -125,8 +100,42 @@ def _train(self): # multi-agent self.local_evaluator.foreach_trainable_policy( lambda pi, pi_id: pi.update_kl(fetches[pi_id]["kl"])) - res = self.optimizer.collect_metrics() + res = self.optimizer.collect_metrics( + self.config["collect_metrics_timeout"]) res.update( timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, info=dict(fetches, **res.get("info", {}))) return res + + def _validate_config(self): + waste_ratio = ( + self.config["sample_batch_size"] * self.config["num_workers"] / + self.config["train_batch_size"]) + if waste_ratio > 1: + msg = ("sample_batch_size * num_workers >> train_batch_size. " + "This means that many steps will be discarded. Consider " + "reducing sample_batch_size, or increase train_batch_size.") + if waste_ratio > 1.5: + raise ValueError(msg) + else: + logger.warn(msg) + if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]: + raise ValueError( + "Minibatch size {} must be <= train batch size {}.".format( + self.config["sgd_minibatch_size"], + self.config["train_batch_size"])) + if (self.config["batch_mode"] == "truncate_episodes" + and not self.config["use_gae"]): + raise ValueError( + "Episode truncation is not supported without a value function") + if (self.config["multiagent"]["policy_graphs"] + and not self.config["simple_optimizer"]): + logger.info( + "In multi-agent mode, policies will be optimized sequentially " + "by the multi-GPU optimizer. 
Consider setting " + "simple_optimizer=True if this doesn't work for you.") + if self.config["observation_filter"] != "NoFilter": + # TODO(ekl): consider setting the default to be NoFilter + logger.warn( + "By default, observations will be normalized with {}".format( + self.config["observation_filter"])) diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index e6fc90d1ce948..6948d810a26be 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -6,10 +6,11 @@ import ray from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ LearningRateSchedule from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.misc import linear, normc_initializer +from ray.rllib.utils.annotations import override from ray.rllib.utils.explained_variance import explained_variance @@ -24,6 +25,7 @@ def __init__(self, curr_action_dist, value_fn, cur_kl_coeff, + valid_mask, entropy_coeff=0, clip_param=0.1, vf_clip_param=0.1, @@ -48,28 +50,33 @@ def __init__(self, value_fn (Tensor): Current value function output Tensor. cur_kl_coeff (Variable): Variable holding the current PPO KL coefficient. + valid_mask (Tensor): A bool mask of valid input elements (#2992). entropy_coeff (float): Coefficient of the entropy regularizer. clip_param (float): Clip parameter vf_clip_param (float): Clip parameter for the value function vf_loss_coeff (float): Coefficient of the value function loss use_gae (bool): If true, use the Generalized Advantage Estimator. """ - dist_cls, _ = ModelCatalog.get_action_dist(action_space) + + def reduce_mean_valid(t): + return tf.reduce_mean(tf.boolean_mask(t, valid_mask)) + + dist_cls, _ = ModelCatalog.get_action_dist(action_space, {}) prev_dist = dist_cls(logits) # Make loss functions. 
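What follows is the PPO clipped surrogate objective: the likelihood ratio between the current policy and the policy that generated the samples is clipped to [1 - clip_param, 1 + clip_param], and the pessimistic minimum of the clipped and unclipped advantage-weighted terms is used. A small numpy illustration of just that term, with toy log-probabilities and advantages:

import numpy as np

clip_param = 0.3
logp_new = np.array([-0.5, -1.2, -0.1])   # log pi_new(a|s) for sampled actions
logp_old = np.array([-0.7, -0.6, -0.9])   # log pi_old(a|s) at sampling time
advantages = np.array([1.0, -2.0, 0.5])

ratio = np.exp(logp_new - logp_old)
surrogate = np.minimum(ratio * advantages,
                       np.clip(ratio, 1 - clip_param, 1 + clip_param) * advantages)
policy_loss = -surrogate.mean()            # minimized by the optimizer
print(ratio.round(3), round(policy_loss, 3))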
logp_ratio = tf.exp( curr_action_dist.logp(actions) - prev_dist.logp(actions)) action_kl = prev_dist.kl(curr_action_dist) - self.mean_kl = tf.reduce_mean(action_kl) + self.mean_kl = reduce_mean_valid(action_kl) curr_entropy = curr_action_dist.entropy() - self.mean_entropy = tf.reduce_mean(curr_entropy) + self.mean_entropy = reduce_mean_valid(curr_entropy) surrogate_loss = tf.minimum( advantages * logp_ratio, advantages * tf.clip_by_value(logp_ratio, 1 - clip_param, 1 + clip_param)) - self.mean_policy_loss = tf.reduce_mean(-surrogate_loss) + self.mean_policy_loss = reduce_mean_valid(-surrogate_loss) if use_gae: vf_loss1 = tf.square(value_fn - value_targets) @@ -77,14 +84,15 @@ def __init__(self, value_fn - vf_preds, -vf_clip_param, vf_clip_param) vf_loss2 = tf.square(vf_clipped - value_targets) vf_loss = tf.maximum(vf_loss1, vf_loss2) - self.mean_vf_loss = tf.reduce_mean(vf_loss) - loss = tf.reduce_mean(-surrogate_loss + cur_kl_coeff * action_kl + - vf_loss_coeff * vf_loss - - entropy_coeff * curr_entropy) + self.mean_vf_loss = reduce_mean_valid(vf_loss) + loss = reduce_mean_valid( + -surrogate_loss + cur_kl_coeff * action_kl + + vf_loss_coeff * vf_loss - entropy_coeff * curr_entropy) else: self.mean_vf_loss = tf.constant(0.0) - loss = tf.reduce_mean(-surrogate_loss + cur_kl_coeff * action_kl - - entropy_coeff * curr_entropy) + loss = reduce_mean_valid(-surrogate_loss + + cur_kl_coeff * action_kl - + entropy_coeff * curr_entropy) self.loss = loss @@ -108,12 +116,14 @@ def __init__(self, self.config = config self.kl_coeff_val = self.config["kl_coeff"] self.kl_target = self.config["kl_target"] - dist_cls, logit_dim = ModelCatalog.get_action_dist(action_space) + dist_cls, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"]) if existing_inputs: obs_ph, value_targets_ph, adv_ph, act_ph, \ - logits_ph, vf_preds_ph = existing_inputs[:6] - existing_state_in = existing_inputs[6:-1] + logits_ph, vf_preds_ph, prev_actions_ph, prev_rewards_ph = \ + existing_inputs[:8] + existing_state_in = existing_inputs[8:-1] existing_seq_lens = existing_inputs[-1] else: obs_ph = tf.placeholder( @@ -129,6 +139,9 @@ def __init__(self, tf.float32, name="vf_preds", shape=(None, )) value_targets_ph = tf.placeholder( tf.float32, name="value_targets", shape=(None, )) + prev_actions_ph = ModelCatalog.get_action_placeholder(action_space) + prev_rewards_ph = tf.placeholder( + tf.float32, [None], name="prev_reward") existing_state_in = None existing_seq_lens = None self.observations = obs_ph @@ -140,9 +153,17 @@ def __init__(self, ("actions", act_ph), ("logits", logits_ph), ("vf_preds", vf_preds_ph), + ("prev_actions", prev_actions_ph), + ("prev_rewards", prev_rewards_ph), ] self.model = ModelCatalog.get_model( - obs_ph, + { + "obs": obs_ph, + "prev_actions": prev_actions_ph, + "prev_rewards": prev_rewards_ph, + "is_training": self._get_is_training_placeholder(), + }, + observation_space, logit_dim, self.config["model"], state_in=existing_state_in, @@ -161,9 +182,7 @@ def __init__(self, self.sampler = curr_action_dist.sample() if self.config["use_gae"]: if self.config["vf_share_layers"]: - self.value_function = tf.reshape( - linear(self.model.last_layer, 1, "value", - normc_initializer(1.0)), [-1]) + self.value_function = self.model.value_function() else: vf_config = self.config["model"].copy() # Do not split the last layer of the value function into @@ -172,12 +191,23 @@ def __init__(self, vf_config["free_log_std"] = False vf_config["use_lstm"] = False with tf.variable_scope("value_function"): - 
self.value_function = ModelCatalog.get_model( - obs_ph, 1, vf_config).outputs + self.value_function = ModelCatalog.get_model({ + "obs": obs_ph, + "prev_actions": prev_actions_ph, + "prev_rewards": prev_rewards_ph, + "is_training": self._get_is_training_placeholder(), + }, observation_space, 1, vf_config).outputs self.value_function = tf.reshape(self.value_function, [-1]) else: self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1]) + if self.model.state_in: + max_seq_len = tf.reduce_max(self.model.seq_lens) + mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) + mask = tf.reshape(mask, [-1]) + else: + mask = tf.ones_like(adv_ph) + self.loss_obj = PPOLoss( action_space, value_targets_ph, @@ -188,6 +218,7 @@ def __init__(self, curr_action_dist, self.value_function, self.kl_coeff, + mask, entropy_coeff=self.config["entropy_coeff"], clip_param=self.config["clip_param"], vf_clip_param=self.config["vf_clip_param"], @@ -203,10 +234,12 @@ def __init__(self, self.sess, obs_input=obs_ph, action_sampler=self.sampler, - loss=self.loss_obj.loss, + loss=self.model.loss() + self.loss_obj.loss, loss_inputs=self.loss_in, state_inputs=self.model.state_in, state_outputs=self.model.state_out, + prev_action_input=prev_actions_ph, + prev_reward_input=prev_rewards_ph, seq_lens=self.model.seq_lens, max_seq_len=config["model"]["max_seq_len"]) @@ -214,6 +247,7 @@ def __init__(self, self.explained_variance = explained_variance(value_targets_ph, self.value_function) self.stats_fetches = { + "cur_kl_coeff": self.kl_coeff, "cur_lr": tf.cast(self.cur_lr, tf.float64), "total_loss": self.loss_obj.loss, "policy_loss": self.loss_obj.mean_policy_loss, @@ -223,38 +257,20 @@ def __init__(self, "entropy": self.loss_obj.mean_entropy } + @override(TFPolicyGraph) def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders.""" return PPOPolicyGraph( - None, + self.observation_space, self.action_space, self.config, existing_inputs=existing_inputs) - def extra_compute_action_fetches(self): - return {"vf_preds": self.value_function, "logits": self.logits} - - def extra_compute_grad_fetches(self): - return self.stats_fetches - - def update_kl(self, sampled_kl): - if sampled_kl > 2.0 * self.kl_target: - self.kl_coeff_val *= 1.5 - elif sampled_kl < 0.5 * self.kl_target: - self.kl_coeff_val *= 0.5 - self.kl_coeff.load(self.kl_coeff_val, session=self.sess) - return self.kl_coeff_val - - def value(self, ob, *args): - feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} - assert len(args) == len(self.model.state_in), \ - (args, self.model.state_in) - for k, v in zip(self.model.state_in, args): - feed_dict[k] = v - vf = self.sess.run(self.value_function, feed_dict) - return vf[0] - - def postprocess_trajectory(self, sample_batch, other_agent_batches=None): + @override(PolicyGraph) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): completed = sample_batch["dones"][-1] if completed: last_r = 0.0 @@ -262,7 +278,7 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None): next_state = [] for i in range(len(self.model.state_in)): next_state.append([sample_batch["state_out_{}".format(i)][-1]]) - last_r = self.value(sample_batch["new_obs"][-1], *next_state) + last_r = self._value(sample_batch["new_obs"][-1], *next_state) batch = compute_advantages( sample_batch, last_r, @@ -271,9 +287,36 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None): use_gae=self.config["use_gae"]) return batch + @override(TFPolicyGraph) 
def gradients(self, optimizer): return optimizer.compute_gradients( self._loss, colocate_gradients_with_ops=True) + @override(PolicyGraph) def get_initial_state(self): return self.model.state_init + + @override(TFPolicyGraph) + def extra_compute_action_fetches(self): + return {"vf_preds": self.value_function, "logits": self.logits} + + @override(TFPolicyGraph) + def extra_compute_grad_fetches(self): + return self.stats_fetches + + def update_kl(self, sampled_kl): + if sampled_kl > 2.0 * self.kl_target: + self.kl_coeff_val *= 1.5 + elif sampled_kl < 0.5 * self.kl_target: + self.kl_coeff_val *= 0.5 + self.kl_coeff.load(self.kl_coeff_val, session=self.sess) + return self.kl_coeff_val + + def _value(self, ob, *args): + feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} + assert len(args) == len(self.model.state_in), \ + (args, self.model.state_in) + for k, v in zip(self.model.state_in, args): + feed_dict[k] = v + vf = self.sess.run(self.value_function, feed_dict) + return vf[0] diff --git a/python/ray/rllib/env/__init__.py b/python/ray/rllib/env/__init__.py index 752d27cecf674..2e9ee49745d22 100644 --- a/python/ray/rllib/env/__init__.py +++ b/python/ray/rllib/env/__init__.py @@ -1,9 +1,11 @@ from ray.rllib.env.async_vector_env import AsyncVectorEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.serving_env import ServingEnv from ray.rllib.env.vector_env import VectorEnv from ray.rllib.env.env_context import EnvContext __all__ = [ - "AsyncVectorEnv", "MultiAgentEnv", "ServingEnv", "VectorEnv", "EnvContext" + "AsyncVectorEnv", "MultiAgentEnv", "ExternalEnv", "VectorEnv", + "ServingEnv", "EnvContext" ] diff --git a/python/ray/rllib/env/async_vector_env.py b/python/ray/rllib/env/async_vector_env.py index c2e5ab1d30864..72cd812de1eae 100644 --- a/python/ray/rllib/env/async_vector_env.py +++ b/python/ray/rllib/env/async_vector_env.py @@ -2,9 +2,10 @@ from __future__ import division from __future__ import print_function -from ray.rllib.env.serving_env import ServingEnv +from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.vector_env import VectorEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.utils.annotations import override class AsyncVectorEnv(object): @@ -20,7 +21,13 @@ class AsyncVectorEnv(object): gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv rllib.MultiAgentEnv => rllib.AsyncVectorEnv - rllib.ServingEnv => rllib.AsyncVectorEnv + rllib.ExternalEnv => rllib.AsyncVectorEnv + + Attributes: + action_space (gym.Space): Action space. This must be defined for + single-agent envs. Multi-agent envs can set this to None. + observation_space (gym.Space): Observation space. This must be defined + for single-agent envs. Multi-agent envs can set this to None. 
Examples: >>> env = MyAsyncVectorEnv() @@ -64,11 +71,11 @@ def wrap_async(env, make_env=None, num_envs=1): if isinstance(env, MultiAgentEnv): env = _MultiAgentEnvToAsync( make_env=make_env, existing_envs=[env], num_envs=num_envs) - elif isinstance(env, ServingEnv): + elif isinstance(env, ExternalEnv): if num_envs != 1: raise ValueError( - "ServingEnv does not currently support num_envs > 1.") - env = _ServingEnvToAsync(env) + "ExternalEnv does not currently support num_envs > 1.") + env = _ExternalEnvToAsync(env) elif isinstance(env, VectorEnv): env = _VectorEnvToAsync(env) else: @@ -139,36 +146,52 @@ def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID): return {k: {dummy_id: v} for (k, v) in env_id_to_values.items()} -class _ServingEnvToAsync(AsyncVectorEnv): - """Internal adapter of ServingEnv to AsyncVectorEnv.""" +class _ExternalEnvToAsync(AsyncVectorEnv): + """Internal adapter of ExternalEnv to AsyncVectorEnv.""" - def __init__(self, serving_env): - self.serving_env = serving_env - serving_env.start() + def __init__(self, external_env, preprocessor=None): + self.external_env = external_env + self.prep = preprocessor + self.action_space = external_env.action_space + if preprocessor: + self.observation_space = preprocessor.observation_space + else: + self.observation_space = external_env.observation_space + external_env.start() + @override(AsyncVectorEnv) def poll(self): - with self.serving_env._results_avail_condition: + with self.external_env._results_avail_condition: results = self._poll() while len(results[0]) == 0: - self.serving_env._results_avail_condition.wait() + self.external_env._results_avail_condition.wait() results = self._poll() - if not self.serving_env.isAlive(): + if not self.external_env.isAlive(): raise Exception("Serving thread has stopped.") - limit = self.serving_env._max_concurrent_episodes + limit = self.external_env._max_concurrent_episodes assert len(results[0]) < limit, \ - ("Too many concurrent episodes, were some leaked? This ServingEnv " - "was created with max_concurrent={}".format(limit)) + ("Too many concurrent episodes, were some leaked? This " + "ExternalEnv was created with max_concurrent={}".format(limit)) return results + @override(AsyncVectorEnv) + def send_actions(self, action_dict): + for eid, action in action_dict.items(): + self.external_env._episodes[eid].action_queue.put( + action[_DUMMY_AGENT_ID]) + def _poll(self): all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {} off_policy_actions = {} - for eid, episode in self.serving_env._episodes.copy().items(): + for eid, episode in self.external_env._episodes.copy().items(): data = episode.get_data() if episode.cur_done: - del self.serving_env._episodes[eid] + del self.external_env._episodes[eid] if data: - all_obs[eid] = data["obs"] + if self.prep: + all_obs[eid] = self.prep.transform(data["obs"]) + else: + all_obs[eid] = data["obs"] all_rewards[eid] = data["reward"] all_dones[eid] = data["done"] all_infos[eid] = data["info"] @@ -180,11 +203,6 @@ def _poll(self): _with_dummy_agent_id(all_infos), \ _with_dummy_agent_id(off_policy_actions) - def send_actions(self, action_dict): - for eid, action in action_dict.items(): - self.serving_env._episodes[eid].action_queue.put( - action[_DUMMY_AGENT_ID]) - class _VectorEnvToAsync(AsyncVectorEnv): """Internal adapter of VectorEnv to AsyncVectorEnv. 
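All of these adapters end up presenting the same AsyncVectorEnv contract: poll() returns observations, rewards, dones, and infos as nested dicts keyed first by environment id and then by agent id (single-agent envs are wrapped under a dummy agent id, and dones carry the special "__all__" key), while send_actions() accepts actions keyed the same way. A toy stand-in showing the shape of that loop; FakeAsyncEnv is invented for illustration and is not an RLlib class:

class FakeAsyncEnv(object):
    """Single fake env/agent that ends its episode after three steps."""

    def __init__(self):
        self.t = 0

    def poll(self):
        obs = {0: {"agent0": [float(self.t)]}}
        rewards = {0: {"agent0": 1.0}}
        dones = {0: {"agent0": self.t >= 3, "__all__": self.t >= 3}}
        infos = {0: {"agent0": {}}}
        return obs, rewards, dones, infos, {}   # last dict: off-policy actions

    def send_actions(self, action_dict):
        assert set(action_dict[0].keys()) == {"agent0"}
        self.t += 1


env = FakeAsyncEnv()
while True:
    obs, rewards, dones, infos, off_policy = env.poll()
    if dones[0]["__all__"]:
        break
    env.send_actions({0: {"agent0": 0}})        # env_id -> agent_id -> action
print("episode finished after", env.t, "steps")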
@@ -196,13 +214,18 @@ class _VectorEnvToAsync(AsyncVectorEnv): def __init__(self, vector_env): self.vector_env = vector_env + self.action_space = vector_env.action_space + self.observation_space = vector_env.observation_space self.num_envs = vector_env.num_envs - self.new_obs = self.vector_env.vector_reset() + self.new_obs = None # lazily initialized self.cur_rewards = [None for _ in range(self.num_envs)] self.cur_dones = [False for _ in range(self.num_envs)] self.cur_infos = [None for _ in range(self.num_envs)] + @override(AsyncVectorEnv) def poll(self): + if self.new_obs is None: + self.new_obs = self.vector_env.vector_reset() new_obs = dict(enumerate(self.new_obs)) rewards = dict(enumerate(self.cur_rewards)) dones = dict(enumerate(self.cur_dones)) @@ -216,6 +239,7 @@ def poll(self): _with_dummy_agent_id(dones, "__all__"), \ _with_dummy_agent_id(infos), {} + @override(AsyncVectorEnv) def send_actions(self, action_dict): action_vector = [None] * self.num_envs for i in range(self.num_envs): @@ -223,9 +247,11 @@ def send_actions(self, action_dict): self.new_obs, self.cur_rewards, self.cur_dones, self.cur_infos = \ self.vector_env.vector_step(action_vector) + @override(AsyncVectorEnv) def try_reset(self, env_id): return {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)} + @override(AsyncVectorEnv) def get_unwrapped(self): return self.vector_env.get_unwrapped() @@ -256,24 +282,39 @@ def __init__(self, make_env, existing_envs, num_envs): assert isinstance(env, MultiAgentEnv) self.env_states = [_MultiAgentEnvState(env) for env in self.envs] + @override(AsyncVectorEnv) def poll(self): obs, rewards, dones, infos = {}, {}, {}, {} for i, env_state in enumerate(self.env_states): obs[i], rewards[i], dones[i], infos[i] = env_state.poll() return obs, rewards, dones, infos, {} + @override(AsyncVectorEnv) def send_actions(self, action_dict): for env_id, agent_dict in action_dict.items(): if env_id in self.dones: raise ValueError("Env {} is already done".format(env_id)) env = self.envs[env_id] obs, rewards, dones, infos = env.step(agent_dict) + assert isinstance(obs, dict), "Not a multi-agent obs" + assert isinstance(rewards, dict), "Not a multi-agent reward" + assert isinstance(dones, dict), "Not a multi-agent return" + assert isinstance(infos, dict), "Not a multi-agent info" + if set(obs.keys()) != set(rewards.keys()): + raise ValueError( + "Key set for obs and rewards must be the same: " + "{} vs {}".format(obs.keys(), rewards.keys())) + if set(obs.keys()) != set(infos.keys()): + raise ValueError("Key set for obs and infos must be the same: " + "{} vs {}".format(obs.keys(), infos.keys())) if dones["__all__"]: self.dones.add(env_id) self.env_states[env_id].observe(obs, rewards, dones, infos) + @override(AsyncVectorEnv) def try_reset(self, env_id): obs = self.env_states[env_id].reset() + assert isinstance(obs, dict), "Not a multi-agent obs" if obs is not None and env_id in self.dones: self.dones.remove(env_id) return obs diff --git a/python/ray/rllib/env/external_env.py b/python/ray/rllib/env/external_env.py new file mode 100644 index 0000000000000..e71c816256dcc --- /dev/null +++ b/python/ray/rllib/env/external_env.py @@ -0,0 +1,226 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import queue +import threading +import uuid + + +class ExternalEnv(threading.Thread): + """An environment that interfaces with external agents. + + Unlike simulator envs, control is inverted. 
The environment queries the + policy to obtain actions and logs observations and rewards for training. + This is in contrast to gym.Env, where the algorithm drives the simulation + through env.step() calls. + + You can use ExternalEnv as the backend for policy serving (by serving HTTP + requests in the run loop), for ingesting offline logs data (by reading + offline transitions in the run loop), or other custom use cases not easily + expressed through gym.Env. + + ExternalEnv supports both on-policy actions (through self.get_action()), + and off-policy actions (through self.log_action()). + + This env is thread-safe, but individual episodes must be executed serially. + + Attributes: + action_space (gym.Space): Action space. + observation_space (gym.Space): Observation space. + + Examples: + >>> register_env("my_env", lambda config: YourExternalEnv(config)) + >>> agent = DQNAgent(env="my_env") + >>> while True: + print(agent.train()) + """ + + def __init__(self, action_space, observation_space, max_concurrent=100): + """Initialize an external env. + + ExternalEnv subclasses must call this during their __init__. + + Arguments: + action_space (gym.Space): Action space of the env. + observation_space (gym.Space): Observation space of the env. + max_concurrent (int): Max number of active episodes to allow at + once. Exceeding this limit raises an error. + """ + + threading.Thread.__init__(self) + self.daemon = True + self.action_space = action_space + self.observation_space = observation_space + self._episodes = {} + self._finished = set() + self._results_avail_condition = threading.Condition() + self._max_concurrent_episodes = max_concurrent + + def run(self): + """Override this to implement the run loop. + + Your loop should continuously: + 1. Call self.start_episode(episode_id) + 2. Call self.get_action(episode_id, obs) + -or- + self.log_action(episode_id, obs, action) + 3. Call self.log_returns(episode_id, reward) + 4. Call self.end_episode(episode_id, obs) + 5. Wait if nothing to do. + + Multiple episodes may be started at the same time. + """ + raise NotImplementedError + + def start_episode(self, episode_id=None, training_enabled=True): + """Record the start of an episode. + + Arguments: + episode_id (str): Unique string id for the episode or None for + it to be auto-assigned. + training_enabled (bool): Whether to use experiences for this + episode to improve the policy. + + Returns: + episode_id (str): Unique string id for the episode. + """ + + if episode_id is None: + episode_id = uuid.uuid4().hex + + if episode_id in self._finished: + raise ValueError( + "Episode {} has already completed.".format(episode_id)) + + if episode_id in self._episodes: + raise ValueError( + "Episode {} is already started".format(episode_id)) + + self._episodes[episode_id] = _ExternalEnvEpisode( + episode_id, self._results_avail_condition, training_enabled) + + return episode_id + + def get_action(self, episode_id, observation): + """Record an observation and get the on-policy action. + + Arguments: + episode_id (str): Episode id returned from start_episode(). + observation (obj): Current environment observation. + + Returns: + action (obj): Action from the env action space. + """ + + episode = self._get(episode_id) + return episode.wait_for_action(observation) + + def log_action(self, episode_id, observation, action): + """Record an observation and (off-policy) action taken. + + Arguments: + episode_id (str): Episode id returned from start_episode(). + observation (obj): Current environment observation. 
+ action (obj): Action for the observation. + """ + + episode = self._get(episode_id) + episode.log_action(observation, action) + + def log_returns(self, episode_id, reward, info=None): + """Record returns from the environment. + + The reward will be attributed to the previous action taken by the + episode. Rewards accumulate until the next action. If no reward is + logged before the next action, a reward of 0.0 is assumed. + + Arguments: + episode_id (str): Episode id returned from start_episode(). + reward (float): Reward from the environment. + info (dict): Optional info dict. + """ + + episode = self._get(episode_id) + episode.cur_reward += reward + if info: + episode.cur_info = info or {} + + def end_episode(self, episode_id, observation): + """Record the end of an episode. + + Arguments: + episode_id (str): Episode id returned from start_episode(). + observation (obj): Current environment observation. + """ + + episode = self._get(episode_id) + self._finished.add(episode.episode_id) + episode.done(observation) + + def _get(self, episode_id): + """Get a started episode or raise an error.""" + + if episode_id in self._finished: + raise ValueError( + "Episode {} has already completed.".format(episode_id)) + + if episode_id not in self._episodes: + raise ValueError("Episode {} not found.".format(episode_id)) + + return self._episodes[episode_id] + + +class _ExternalEnvEpisode(object): + """Tracked state for each active episode.""" + + def __init__(self, episode_id, results_avail_condition, training_enabled): + self.episode_id = episode_id + self.results_avail_condition = results_avail_condition + self.training_enabled = training_enabled + self.data_queue = queue.Queue() + self.action_queue = queue.Queue() + self.new_observation = None + self.new_action = None + self.cur_reward = 0.0 + self.cur_done = False + self.cur_info = {} + + def get_data(self): + if self.data_queue.empty(): + return None + return self.data_queue.get_nowait() + + def log_action(self, observation, action): + self.new_observation = observation + self.new_action = action + self._send() + self.action_queue.get(True, timeout=60.0) + + def wait_for_action(self, observation): + self.new_observation = observation + self._send() + return self.action_queue.get(True, timeout=60.0) + + def done(self, observation): + self.new_observation = observation + self.cur_done = True + self._send() + + def _send(self): + item = { + "obs": self.new_observation, + "reward": self.cur_reward, + "done": self.cur_done, + "info": self.cur_info, + } + if self.new_action is not None: + item["off_policy_action"] = self.new_action + if not self.training_enabled: + item["info"]["training_enabled"] = False + self.new_observation = None + self.new_action = None + self.cur_reward = 0.0 + with self.results_avail_condition: + self.data_queue.put_nowait(item) + self.results_avail_condition.notify() diff --git a/python/ray/rllib/env/multi_agent_env.py b/python/ray/rllib/env/multi_agent_env.py index 42f7cee8c0428..2e569230a2120 100644 --- a/python/ray/rllib/env/multi_agent_env.py +++ b/python/ray/rllib/env/multi_agent_env.py @@ -56,7 +56,7 @@ def step(self, action_dict): rewards (dict): Reward values for each ready agent. If the episode is just started, the value will be None. dones (dict): Done values for each ready agent. The special key - "__all__" is used to indicate env termination. + "__all__" (required) is used to indicate env termination. infos (dict): Info values for each ready agent. 
""" raise NotImplementedError diff --git a/python/ray/rllib/env/serving_env.py b/python/ray/rllib/env/serving_env.py index 0c1e3ec0dbfe4..cb976bf8041e9 100644 --- a/python/ray/rllib/env/serving_env.py +++ b/python/ray/rllib/env/serving_env.py @@ -2,219 +2,7 @@ from __future__ import division from __future__ import print_function -from six.moves import queue -import threading -import uuid +from ray.rllib.env.external_env import ExternalEnv - -class ServingEnv(threading.Thread): - """An environment that provides policy serving. - - Unlike simulator envs, control is inverted. The environment queries the - policy to obtain actions and logs observations and rewards for training. - This is in contrast to gym.Env, where the algorithm drives the simulation - through env.step() calls. - - You can use ServingEnv as the backend for policy serving (by serving HTTP - requests in the run loop), for ingesting offline logs data (by reading - offline transitions in the run loop), or other custom use cases not easily - expressed through gym.Env. - - ServingEnv supports both on-policy serving (through self.get_action()), and - off-policy serving (through self.log_action()). - - This env is thread-safe, but individual episodes must be executed serially. - - Examples: - >>> register_env("my_env", lambda config: YourServingEnv(config)) - >>> agent = DQNAgent(env="my_env") - >>> while True: - print(agent.train()) - """ - - def __init__(self, action_space, observation_space, max_concurrent=100): - """Initialize a serving env. - - ServingEnv subclasses must call this during their __init__. - - Arguments: - action_space (gym.Space): Action space of the env. - observation_space (gym.Space): Observation space of the env. - max_concurrent (int): Max number of active episodes to allow at - once. Exceeding this limit raises an error. - """ - - threading.Thread.__init__(self) - self.daemon = True - self.action_space = action_space - self.observation_space = observation_space - self._episodes = {} - self._finished = set() - self._results_avail_condition = threading.Condition() - self._max_concurrent_episodes = max_concurrent - - def run(self): - """Override this to implement the run loop. - - Your loop should continuously: - 1. Call self.start_episode() - 2. Call self.get_action() or self.log_action() - 3. Call self.log_returns() - 4. Call self.end_episode() - 5. Wait if nothing to do. - - Multiple episodes may be started at the same time. - """ - raise NotImplementedError - - def start_episode(self, episode_id=None, training_enabled=True): - """Record the start of an episode. - - Arguments: - episode_id (str): Unique string id for the episode or None for - it to be auto-assigned. - training_enabled (bool): Whether to use experiences for this - episode to improve the policy. - - Returns: - episode_id (str): Unique string id for the episode. - """ - - if episode_id is None: - episode_id = uuid.uuid4().hex - - if episode_id in self._finished: - raise ValueError( - "Episode {} has already completed.".format(episode_id)) - - if episode_id in self._episodes: - raise ValueError( - "Episode {} is already started".format(episode_id)) - - self._episodes[episode_id] = _ServingEnvEpisode( - episode_id, self._results_avail_condition, training_enabled) - - return episode_id - - def get_action(self, episode_id, observation): - """Record an observation and get the on-policy action. - - Arguments: - episode_id (str): Episode id returned from start_episode(). - observation (obj): Current environment observation. 
- - Returns: - action (obj): Action from the env action space. - """ - - episode = self._get(episode_id) - return episode.wait_for_action(observation) - - def log_action(self, episode_id, observation, action): - """Record an observation and (off-policy) action taken. - - Arguments: - episode_id (str): Episode id returned from start_episode(). - observation (obj): Current environment observation. - action (obj): Action for the observation. - """ - - episode = self._get(episode_id) - episode.log_action(observation, action) - - def log_returns(self, episode_id, reward, info=None): - """Record returns from the environment. - - The reward will be attributed to the previous action taken by the - episode. Rewards accumulate until the next action. If no reward is - logged before the next action, a reward of 0.0 is assumed. - - Arguments: - episode_id (str): Episode id returned from start_episode(). - reward (float): Reward from the environment. - info (dict): Optional info dict. - """ - - episode = self._get(episode_id) - episode.cur_reward += reward - if info: - episode.cur_info = info or {} - - def end_episode(self, episode_id, observation): - """Record the end of an episode. - - Arguments: - episode_id (str): Episode id returned from start_episode(). - observation (obj): Current environment observation. - """ - - episode = self._get(episode_id) - self._finished.add(episode.episode_id) - episode.done(observation) - - def _get(self, episode_id): - """Get a started episode or raise an error.""" - - if episode_id in self._finished: - raise ValueError( - "Episode {} has already completed.".format(episode_id)) - - if episode_id not in self._episodes: - raise ValueError("Episode {} not found.".format(episode_id)) - - return self._episodes[episode_id] - - -class _ServingEnvEpisode(object): - """Tracked state for each active episode.""" - - def __init__(self, episode_id, results_avail_condition, training_enabled): - self.episode_id = episode_id - self.results_avail_condition = results_avail_condition - self.training_enabled = training_enabled - self.data_queue = queue.Queue() - self.action_queue = queue.Queue() - self.new_observation = None - self.new_action = None - self.cur_reward = 0.0 - self.cur_done = False - self.cur_info = {} - - def get_data(self): - if self.data_queue.empty(): - return None - return self.data_queue.get_nowait() - - def log_action(self, observation, action): - self.new_observation = observation - self.new_action = action - self._send() - self.action_queue.get(True, timeout=60.0) - - def wait_for_action(self, observation): - self.new_observation = observation - self._send() - return self.action_queue.get(True, timeout=60.0) - - def done(self, observation): - self.new_observation = observation - self.cur_done = True - self._send() - - def _send(self): - item = { - "obs": self.new_observation, - "reward": self.cur_reward, - "done": self.cur_done, - "info": self.cur_info, - } - if self.new_action is not None: - item["off_policy_action"] = self.new_action - if not self.training_enabled: - item["info"]["training_enabled"] = False - self.new_observation = None - self.new_action = None - self.cur_reward = 0.0 - with self.results_avail_condition: - self.data_queue.put_nowait(item) - self.results_avail_condition.notify() +# renamed to ExternalEnv in 0.6 +ServingEnv = ExternalEnv diff --git a/python/ray/rllib/env/vector_env.py b/python/ray/rllib/env/vector_env.py index 7fb5b1605543e..c2eb1692061ce 100644 --- a/python/ray/rllib/env/vector_env.py +++ b/python/ray/rllib/env/vector_env.py @@ 
-2,6 +2,8 @@ from __future__ import division from __future__ import print_function +from ray.rllib.utils.annotations import override + class VectorEnv(object): """An environment that supports batch evaluation. @@ -69,13 +71,18 @@ def __init__(self, make_env, existing_envs, num_envs): self.num_envs = num_envs while len(self.envs) < self.num_envs: self.envs.append(self.make_env(len(self.envs))) + self.action_space = self.envs[0].action_space + self.observation_space = self.envs[0].observation_space + @override(VectorEnv) def vector_reset(self): return [e.reset() for e in self.envs] + @override(VectorEnv) def reset_at(self, index): return self.envs[index].reset() + @override(VectorEnv) def vector_step(self, actions): obs_batch, rew_batch, done_batch, info_batch = [], [], [], [] for i in range(self.num_envs): @@ -86,5 +93,6 @@ def vector_step(self, actions): info_batch.append(info) return obs_batch, rew_batch, done_batch, info_batch + @override(VectorEnv) def get_unwrapped(self): return self.envs diff --git a/python/ray/rllib/evaluation/episode.py b/python/ray/rllib/evaluation/episode.py index fc99d79fbb041..11977745184d5 100644 --- a/python/ray/rllib/evaluation/episode.py +++ b/python/ray/rllib/evaluation/episode.py @@ -7,13 +7,15 @@ import numpy as np +from ray.rllib.env.async_vector_env import _DUMMY_AGENT_ID + class MultiAgentEpisode(object): """Tracks the current state of a (possibly multi-agent) episode. The APIs in this class should be considered experimental, but we should avoid changing things for the sake of changing them since users may - depend on them for advanced algorithms. + depend on them for custom metrics or advanced algorithms. Attributes: new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder. @@ -23,6 +25,8 @@ class MultiAgentEpisode(object): length (int): Length of this episode. episode_id (int): Unique id identifying this trajectory. agent_rewards (dict): Summed rewards broken down by agent. + custom_metrics (dict): Dict where the you can add custom metrics. + user_data (dict): Dict that you can use for temporary storage. Use case 1: Model-based rollouts in multi-agent: A custom compute_actions() function in a policy graph can inspect the @@ -47,15 +51,22 @@ def __init__(self, policies, policy_mapping_fn, batch_builder_factory, self.length = 0 self.episode_id = random.randrange(2e9) self.agent_rewards = defaultdict(float) + self.custom_metrics = {} + self.user_data = {} self._policies = policies self._policy_mapping_fn = policy_mapping_fn + self._next_agent_index = 0 + self._agent_to_index = {} self._agent_to_policy = {} self._agent_to_rnn_state = {} self._agent_to_last_obs = {} + self._agent_to_last_info = {} self._agent_to_last_action = {} self._agent_to_last_pi_info = {} + self._agent_to_prev_action = {} + self._agent_reward_history = defaultdict(list) - def policy_for(self, agent_id): + def policy_for(self, agent_id=_DUMMY_AGENT_ID): """Returns the policy graph for the specified agent. 
If the agent is new, the policy mapping fn will be called to bind the @@ -66,27 +77,46 @@ def policy_for(self, agent_id): self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id) return self._agent_to_policy[agent_id] - def last_observation_for(self, agent_id): + def last_observation_for(self, agent_id=_DUMMY_AGENT_ID): """Returns the last observation for the specified agent.""" return self._agent_to_last_obs.get(agent_id) - def last_action_for(self, agent_id): - """Returns the last action for the specified agent.""" - - action = self._agent_to_last_action[agent_id] - # Concatenate tuple actions - if isinstance(action, list): - expanded = [] - for a in action: - if len(a.shape) == 1: - expanded.append(np.expand_dims(a, 1)) - else: - expanded.append(a) - action = np.concatenate(expanded, axis=1).flatten() - return action - - def rnn_state_for(self, agent_id): + def last_info_for(self, agent_id=_DUMMY_AGENT_ID): + """Returns the last info for the specified agent.""" + + return self._agent_to_last_info.get(agent_id) + + def last_action_for(self, agent_id=_DUMMY_AGENT_ID): + """Returns the last action for the specified agent, or zeros.""" + + if agent_id in self._agent_to_last_action: + return _flatten_action(self._agent_to_last_action[agent_id]) + else: + policy = self._policies[self.policy_for(agent_id)] + flat = _flatten_action(policy.action_space.sample()) + return np.zeros_like(flat) + + def prev_action_for(self, agent_id=_DUMMY_AGENT_ID): + """Returns the previous action for the specified agent.""" + + if agent_id in self._agent_to_prev_action: + return _flatten_action(self._agent_to_prev_action[agent_id]) + else: + # We're at t=0, so return all zeros. + return np.zeros_like(self.last_action_for(agent_id)) + + def prev_reward_for(self, agent_id=_DUMMY_AGENT_ID): + """Returns the previous reward for the specified agent.""" + + history = self._agent_reward_history[agent_id] + if len(history) >= 2: + return history[-2] + else: + # We're at t=0, so there is no previous reward, just return zero. 
+ return 0.0 + + def rnn_state_for(self, agent_id=_DUMMY_AGENT_ID): """Returns the last RNN state for the specified agent.""" if agent_id not in self._agent_to_rnn_state: @@ -94,7 +124,7 @@ def rnn_state_for(self, agent_id): self._agent_to_rnn_state[agent_id] = policy.get_initial_state() return self._agent_to_rnn_state[agent_id] - def last_pi_info_for(self, agent_id): + def last_pi_info_for(self, agent_id=_DUMMY_AGENT_ID): """Returns the last info object for the specified agent.""" return self._agent_to_last_pi_info[agent_id] @@ -105,6 +135,7 @@ def _add_agent_rewards(self, reward_dict): self.agent_rewards[agent_id, self.policy_for(agent_id)] += reward self.total_reward += reward + self._agent_reward_history[agent_id].append(reward) def _set_rnn_state(self, agent_id, rnn_state): self._agent_to_rnn_state[agent_id] = rnn_state @@ -112,8 +143,30 @@ def _set_rnn_state(self, agent_id, rnn_state): def _set_last_observation(self, agent_id, obs): self._agent_to_last_obs[agent_id] = obs + def _set_last_info(self, agent_id, info): + self._agent_to_last_info[agent_id] = info + def _set_last_action(self, agent_id, action): self._agent_to_last_action[agent_id] = action def _set_last_pi_info(self, agent_id, pi_info): self._agent_to_last_pi_info[agent_id] = pi_info + + def _agent_index(self, agent_id): + if agent_id not in self._agent_to_index: + self._agent_to_index[agent_id] = self._next_agent_index + self._next_agent_index += 1 + return self._agent_to_index[agent_id] + + +def _flatten_action(action): + # Concatenate tuple actions + if isinstance(action, list) or isinstance(action, tuple): + expanded = [] + for a in action: + if not hasattr(a, "shape") or len(a.shape) == 0: + expanded.append(np.expand_dims(a, 1)) + else: + expanded.append(a) + action = np.concatenate(expanded, axis=0).flatten() + return action diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py index dc71c4ecd1185..92c357d117e85 100644 --- a/python/ray/rllib/evaluation/metrics.py +++ b/python/ray/rllib/evaluation/metrics.py @@ -2,48 +2,69 @@ from __future__ import division from __future__ import print_function +import logging import numpy as np import collections import ray from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +logger = logging.getLogger(__name__) -def collect_metrics(local_evaluator, remote_evaluators=[]): + +def collect_metrics(local_evaluator, remote_evaluators=[], + timeout_seconds=180): """Gathers episode metrics from PolicyEvaluator instances.""" - episodes = collect_episodes(local_evaluator, remote_evaluators) - return summarize_episodes(episodes, episodes) + episodes, num_dropped = collect_episodes( + local_evaluator, remote_evaluators, timeout_seconds=timeout_seconds) + metrics = summarize_episodes(episodes, episodes, num_dropped) + return metrics -def collect_episodes(local_evaluator, remote_evaluators=[]): +def collect_episodes(local_evaluator, + remote_evaluators=[], + timeout_seconds=180): """Gathers new episodes metrics tuples from the given evaluators.""" - metric_lists = ray.get([ + pending = [ a.apply.remote(lambda ev: ev.sampler.get_metrics()) for a in remote_evaluators - ]) + ] + collected, _ = ray.wait( + pending, num_returns=len(pending), timeout=timeout_seconds * 1000) + num_metric_batches_dropped = len(pending) - len(collected) + + metric_lists = ray.get(collected) metric_lists.append(local_evaluator.sampler.get_metrics()) episodes = [] for metrics in metric_lists: episodes.extend(metrics) - return episodes + return episodes, 
num_metric_batches_dropped -def summarize_episodes(episodes, new_episodes): +def summarize_episodes(episodes, new_episodes, num_dropped): """Summarizes a set of episode metrics tuples. Arguments: episodes: smoothed set of episodes including historical ones new_episodes: just the new episodes in this iteration + num_dropped: number of workers haven't returned their metrics """ + if num_dropped > 0: + logger.warn("WARNING: {} workers have NOT returned metrics".format( + num_dropped)) + episode_rewards = [] episode_lengths = [] policy_rewards = collections.defaultdict(list) + custom_metrics = collections.defaultdict(list) for episode in episodes: episode_lengths.append(episode.episode_length) episode_rewards.append(episode.episode_reward) + for k, v in episode.custom_metrics.items(): + custom_metrics[k].append(v) for (_, policy_id), reward in episode.agent_rewards.items(): if policy_id != DEFAULT_POLICY_ID: policy_rewards[policy_id].append(reward) @@ -59,10 +80,23 @@ def summarize_episodes(episodes, new_episodes): for policy_id, rewards in policy_rewards.copy().items(): policy_rewards[policy_id] = np.mean(rewards) + for k, v_list in custom_metrics.copy().items(): + custom_metrics[k + "_mean"] = np.mean(v_list) + filt = [v for v in v_list if not np.isnan(v)] + if filt: + custom_metrics[k + "_min"] = np.min(filt) + custom_metrics[k + "_max"] = np.max(filt) + else: + custom_metrics[k + "_min"] = float("nan") + custom_metrics[k + "_max"] = float("nan") + del custom_metrics[k] + return dict( episode_reward_max=max_reward, episode_reward_min=min_reward, episode_reward_mean=avg_reward, episode_len_mean=avg_length, episodes_this_iter=len(new_episodes), - policy_reward_mean=dict(policy_rewards)) + policy_reward_mean=dict(policy_rewards), + custom_metrics=dict(custom_metrics), + num_metric_batches_dropped=num_dropped) diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index 24eb746100d56..b97f6a27bd00c 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -3,27 +3,31 @@ from __future__ import print_function import gym +import logging import pickle import tensorflow as tf import ray -from ray.rllib.models import ModelCatalog from ray.rllib.env.async_vector_env import AsyncVectorEnv from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari from ray.rllib.env.env_context import EnvContext -from ray.rllib.env.serving_env import ServingEnv -from ray.rllib.env.vector_env import VectorEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.evaluation.interface import EvaluatorInterface from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \ DEFAULT_POLICY_ID from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler -from ray.rllib.utils.compression import pack -from ray.rllib.utils.filter import get_filter from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.models import ModelCatalog +from ray.rllib.models.preprocessors import NoPreprocessor +from ray.rllib.utils import merge_dicts +from ray.rllib.utils.annotations import override +from ray.rllib.utils.compression import pack +from ray.rllib.utils.filter import get_filter from ray.rllib.utils.tf_run_builder import TFRunBuilder +logger = logging.getLogger(__name__) + class PolicyEvaluator(EvaluatorInterface): """Common ``PolicyEvaluator`` implementation that wraps a ``PolicyGraph``. 
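The custom_metrics aggregation above pairs with the episode callbacks added elsewhere in this patch (on_episode_start/step/end and PolicyEvaluator's new `callbacks` argument). A minimal sketch of how user code might populate episode.custom_metrics so that summarize_episodes() reports pole_angle_mean/_min/_max; the CartPole observation layout (index 2 is the pole angle) and the per-episode averaging are assumptions, not part of the patch.

def on_episode_start(info):
    # info is {"env": async_vector_env, "episode": MultiAgentEpisode}
    info["episode"].user_data["pole_angles"] = []


def on_episode_step(info):
    episode = info["episode"]
    # last_observation_for() returns the preprocessed, filtered observation;
    # for CartPole with the default NoFilter this is the raw observation.
    episode.user_data["pole_angles"].append(
        abs(episode.last_observation_for()[2]))


def on_episode_end(info):
    episode = info["episode"]
    angles = episode.user_data["pole_angles"]
    # Surfaces in collect_metrics() output as pole_angle_mean/_min/_max.
    episode.custom_metrics["pole_angle"] = sum(angles) / max(len(angles), 1)


callbacks = {
    "on_episode_start": on_episode_start,
    "on_episode_step": on_episode_step,
    "on_episode_end": on_episode_end,
}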
@@ -71,7 +75,7 @@ class PolicyEvaluator(EvaluatorInterface): ... policy_mapping_fn=lambda agent_id: ... random.choice(["car_policy1", "car_policy2"]) ... if agent_id.startswith("car_") else "traffic_light_policy") - >>> print(evaluator.sample().keys()) + >>> print(evaluator.sample()) MultiAgentBatch({ "car_policy1": SampleBatch(...), "car_policy2": SampleBatch(...), @@ -79,8 +83,9 @@ class PolicyEvaluator(EvaluatorInterface): """ @classmethod - def as_remote(cls, num_cpus=None, num_gpus=None): - return ray.remote(num_cpus=num_cpus, num_gpus=num_gpus)(cls) + def as_remote(cls, num_cpus=None, num_gpus=None, resources=None): + return ray.remote( + num_cpus=num_cpus, num_gpus=num_gpus, resources=resources)(cls) def __init__(self, env_creator, @@ -97,11 +102,14 @@ def __init__(self, num_envs=1, observation_filter="NoFilter", clip_rewards=None, + clip_actions=True, env_config=None, model_config=None, policy_config=None, worker_index=0, - monitor_path=None): + monitor_path=None, + log_level=None, + callbacks=None): """Initialize a policy evaluator. Arguments: @@ -124,16 +132,14 @@ def __init__(self, in each sample batch returned from this evaluator. batch_mode (str): One of the following batch modes: "truncate_episodes": Each call to sample() will return a batch - of at most `batch_steps` in size. The batch will be exactly - `batch_steps` in size if postprocessing does not change - batch sizes. Episodes may be truncated in order to meet - this size requirement. When `num_envs > 1`, episodes will - be truncated to sequences of `batch_size / num_envs` in - length. + of at most `batch_steps * num_envs` in size. The batch will + be exactly `batch_steps * num_envs` in size if + postprocessing does not change batch sizes. Episodes may be + truncated in order to meet this size requirement. "complete_episodes": Each call to sample() will return a batch - of at least `batch_steps in size. Episodes will not be - truncated, but multiple episodes may be packed within one - batch to meet the batch size. Note that when + of at least `batch_steps * num_envs` in size. Episodes will + not be truncated, but multiple episodes may be packed + within one batch to meet the batch size. Note that when `num_envs > 1`, episode steps will be buffered until the episode completes, and hence batches may contain significant amounts of off-policy data. @@ -152,6 +158,8 @@ def __init__(self, clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to experience postprocessing. Setting to None means clip for Atari only. + clip_actions (bool): Whether to clip action values to the range + specified by the policy action space. env_config (dict): Config to pass to the env creator. model_config (dict): Config to use when creating the policy model. policy_config (dict): Config to pass to the policy. In the @@ -162,48 +170,58 @@ def __init__(self, through EnvContext so that envs can be configured per worker. monitor_path (str): Write out episode stats and videos to this directory if specified. + log_level (str): Set the root log level on creation. + callbacks (dict): Dict of custom debug callbacks. 
""" + if log_level: + logging.getLogger("ray.rllib").setLevel(log_level) + env_context = EnvContext(env_config or {}, worker_index) policy_config = policy_config or {} self.policy_config = policy_config + self.callbacks = callbacks or {} model_config = model_config or {} policy_mapping_fn = (policy_mapping_fn or (lambda agent_id: DEFAULT_POLICY_ID)) + if not callable(policy_mapping_fn): + raise ValueError( + "Policy mapping function not callable. If you're using Tune, " + "make sure to escape the function with tune.function() " + "to prevent it from being evaluated as an expression.") self.env_creator = env_creator - self.batch_steps = batch_steps + self.sample_batch_size = batch_steps * num_envs self.batch_mode = batch_mode self.compress_observations = compress_observations + self.preprocessing_enabled = True self.env = env_creator(env_context) - if isinstance(self.env, VectorEnv) or \ - isinstance(self.env, ServingEnv) or \ - isinstance(self.env, MultiAgentEnv) or \ + if isinstance(self.env, MultiAgentEnv) or \ isinstance(self.env, AsyncVectorEnv): def wrap(env): return env # we can't auto-wrap these env types elif is_atari(self.env) and \ - "custom_preprocessor" not in model_config and \ + not model_config.get("custom_preprocessor") and \ preprocessor_pref == "deepmind": + # Deepmind wrappers already handle all preprocessing + self.preprocessing_enabled = False + if clip_rewards is None: clip_rewards = True def wrap(env): env = wrap_deepmind( env, - dim=model_config.get("dim", 84), - framestack=not model_config.get("use_lstm") - and not model_config.get("no_framestack")) + dim=model_config.get("dim"), + framestack=model_config.get("framestack")) if monitor_path: env = _monitor(env, monitor_path) return env else: def wrap(env): - env = ModelCatalog.get_preprocessor_as_wrapper( - env, model_config) if monitor_path: env = _monitor(env, monitor_path) return env @@ -226,13 +244,20 @@ def make_env(vector_index): config=tf.ConfigProto( gpu_options=tf.GPUOptions(allow_growth=True))) with self.tf_sess.as_default(): - self.policy_map = self._build_policy_map( - policy_dict, policy_config) + self.policy_map, self.preprocessors = \ + self._build_policy_map(policy_dict, policy_config) else: - self.policy_map = self._build_policy_map(policy_dict, - policy_config) + self.policy_map, self.preprocessors = self._build_policy_map( + policy_dict, policy_config) - self.multiagent = self.policy_map.keys() != {DEFAULT_POLICY_ID} + self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID} + if self.multiagent: + if not (isinstance(self.env, MultiAgentEnv) + or isinstance(self.env, AsyncVectorEnv)): + raise ValueError( + "Have multiple policy graphs {}, but the env ".format( + self.policy_map) + + "{} is not a subclass of MultiAgentEnv?".format(self.env)) self.filters = { policy_id: get_filter(observation_filter, @@ -246,15 +271,10 @@ def make_env(vector_index): self.num_envs = num_envs if self.batch_mode == "truncate_episodes": - if batch_steps % num_envs != 0: - raise ValueError( - "In 'truncate_episodes' batch mode, `batch_steps` must be " - "evenly divisible by `num_envs`. 
Got {} and {}.".format( - batch_steps, num_envs)) - batch_steps = batch_steps // num_envs + unroll_length = batch_steps pack_episodes = True elif self.batch_mode == "complete_episodes": - batch_steps = float("inf") # never cut episodes + unroll_length = float("inf") # never cut episodes pack_episodes = False # sampler will return 1 episode per poll else: raise ValueError("Unsupported batch mode: {}".format( @@ -264,35 +284,35 @@ def make_env(vector_index): self.async_env, self.policy_map, policy_mapping_fn, + self.preprocessors, self.filters, clip_rewards, - batch_steps, + unroll_length, + self.callbacks, horizon=episode_horizon, pack=pack_episodes, - tf_sess=self.tf_sess) + tf_sess=self.tf_sess, + clip_actions=clip_actions) self.sampler.start() else: self.sampler = SyncSampler( self.async_env, self.policy_map, policy_mapping_fn, + self.preprocessors, self.filters, clip_rewards, - batch_steps, + unroll_length, + self.callbacks, horizon=episode_horizon, pack=pack_episodes, - tf_sess=self.tf_sess) + tf_sess=self.tf_sess, + clip_actions=clip_actions) - def _build_policy_map(self, policy_dict, policy_config): - policy_map = {} - for name, (cls, obs_space, act_space, - conf) in sorted(policy_dict.items()): - merged_conf = policy_config.copy() - merged_conf.update(conf) - with tf.variable_scope(name): - policy_map[name] = cls(obs_space, act_space, merged_conf) - return policy_map + logger.debug("Created evaluator with env {} ({}), policies {}".format( + self.async_env, self.env, self.policy_map)) + @override(EvaluatorInterface) def sample(self): """Evaluate the current policies and return a batch of experiences. @@ -310,13 +330,20 @@ def sample(self): else: max_batches = float("inf") - while steps_so_far < self.batch_steps and len(batches) < max_batches: + while steps_so_far < self.sample_batch_size and len( + batches) < max_batches: batch = self.sampler.get_data() steps_so_far += batch.count batches.append(batch) batches.extend(self.sampler.get_extra_batches()) batch = batches[0].concat_samples(batches) + if self.callbacks.get("on_sample_end"): + self.callbacks["on_sample_end"]({ + "evaluator": self, + "samples": batch + }) + if self.compress_observations: if isinstance(batch, MultiAgentBatch): for data in batch.policy_batches.values(): @@ -334,52 +361,7 @@ def sample_with_count(self): batch = self.sample() return batch, batch.count - def for_policy(self, func, policy_id=DEFAULT_POLICY_ID): - """Apply the given function to the specified policy graph.""" - - return func(self.policy_map[policy_id]) - - def foreach_policy(self, func): - """Apply the given function to each (policy, policy_id) tuple.""" - - return [func(policy, pid) for pid, policy in self.policy_map.items()] - - def foreach_trainable_policy(self, func): - """Apply the given function to each (policy, policy_id) tuple. - - This only applies func to policies in `self.policies_to_train`.""" - - return [ - func(policy, pid) for pid, policy in self.policy_map.items() - if pid in self.policies_to_train - ] - - def sync_filters(self, new_filters): - """Changes self's filter to given and rebases any accumulated delta. - - Args: - new_filters (dict): Filters with new state to update local copy. - """ - assert all(k in new_filters for k in self.filters) - for k in self.filters: - self.filters[k].sync(new_filters[k]) - - def get_filters(self, flush_after=False): - """Returns a snapshot of filters. - - Args: - flush_after (bool): Clears the filter buffer state. 
- - Returns: - return_filters (dict): Dict for serializable filters - """ - return_filters = {} - for k, f in self.filters.items(): - return_filters[k] = f.as_serializable() - if flush_after: - f.clear_buffer() - return return_filters - + @override(EvaluatorInterface) def get_weights(self, policies=None): if policies is None: policies = self.policy_map.keys() @@ -388,10 +370,12 @@ def get_weights(self, policies=None): for pid, policy in self.policy_map.items() if pid in policies } + @override(EvaluatorInterface) def set_weights(self, weights): for pid, w in weights.items(): self.policy_map[pid].set_weights(w) + @override(EvaluatorInterface) def compute_gradients(self, samples): if isinstance(samples, MultiAgentBatch): grad_out, info_out = {}, {} @@ -401,12 +385,14 @@ def compute_gradients(self, samples): if pid not in self.policies_to_train: continue grad_out[pid], info_out[pid] = ( - self.policy_map[pid].build_compute_gradients( + self.policy_map[pid]._build_compute_gradients( builder, batch)) grad_out = {k: builder.get(v) for k, v in grad_out.items()} info_out = {k: builder.get(v) for k, v in info_out.items()} else: for pid, batch in samples.policy_batches.items(): + if pid not in self.policies_to_train: + continue grad_out[pid], info_out[pid] = ( self.policy_map[pid].compute_gradients(batch)) else: @@ -415,12 +401,13 @@ def compute_gradients(self, samples): info_out["batch_count"] = samples.count return grad_out, info_out + @override(EvaluatorInterface) def apply_gradients(self, grads): if isinstance(grads, dict): if self.tf_sess is not None: builder = TFRunBuilder(self.tf_sess, "apply_gradients") outputs = { - pid: self.policy_map[pid].build_apply_gradients( + pid: self.policy_map[pid]._build_apply_gradients( builder, grad) for pid, grad in grads.items() } @@ -433,6 +420,7 @@ def apply_gradients(self, grads): else: return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads) + @override(EvaluatorInterface) def compute_apply(self, samples): if isinstance(samples, MultiAgentBatch): info_out = {} @@ -442,11 +430,13 @@ def compute_apply(self, samples): if pid not in self.policies_to_train: continue info_out[pid], _ = ( - self.policy_map[pid].build_compute_apply( + self.policy_map[pid]._build_compute_apply( builder, batch)) info_out = {k: builder.get(v) for k, v in info_out.items()} else: for pid, batch in samples.policy_batches.items(): + if pid not in self.policies_to_train: + continue info_out[pid], _ = ( self.policy_map[pid].compute_apply(batch)) return info_out @@ -455,6 +445,52 @@ def compute_apply(self, samples): self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples)) return grad_fetch + def for_policy(self, func, policy_id=DEFAULT_POLICY_ID): + """Apply the given function to the specified policy graph.""" + + return func(self.policy_map[policy_id]) + + def foreach_policy(self, func): + """Apply the given function to each (policy, policy_id) tuple.""" + + return [func(policy, pid) for pid, policy in self.policy_map.items()] + + def foreach_trainable_policy(self, func): + """Apply the given function to each (policy, policy_id) tuple. + + This only applies func to policies in `self.policies_to_train`.""" + + return [ + func(policy, pid) for pid, policy in self.policy_map.items() + if pid in self.policies_to_train + ] + + def sync_filters(self, new_filters): + """Changes self's filter to given and rebases any accumulated delta. + + Args: + new_filters (dict): Filters with new state to update local copy. 
+ """ + assert all(k in new_filters for k in self.filters) + for k in self.filters: + self.filters[k].sync(new_filters[k]) + + def get_filters(self, flush_after=False): + """Returns a snapshot of filters. + + Args: + flush_after (bool): Clears the filter buffer state. + + Returns: + return_filters (dict): Dict for serializable filters + """ + return_filters = {} + for k, f in self.filters.items(): + return_filters[k] = f.as_serializable() + if flush_after: + f.clear_buffer() + return return_filters + def save(self): filters = self.get_filters(flush_after=True) state = { @@ -472,6 +508,29 @@ def restore(self, objs): def set_global_vars(self, global_vars): self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars)) + def _build_policy_map(self, policy_dict, policy_config): + policy_map = {} + preprocessors = {} + for name, (cls, obs_space, act_space, + conf) in sorted(policy_dict.items()): + merged_conf = merge_dicts(policy_config, conf) + if self.preprocessing_enabled: + preprocessor = ModelCatalog.get_preprocessor_for_space( + obs_space, merged_conf.get("model")) + preprocessors[name] = preprocessor + obs_space = preprocessor.observation_space + else: + preprocessors[name] = NoPreprocessor(obs_space) + if isinstance(obs_space, gym.spaces.Dict) or \ + isinstance(obs_space, gym.spaces.Tuple): + raise ValueError( + "Found raw Tuple|Dict space as input to policy graph. " + "Please preprocess these observations with a " + "Tuple|DictFlatteningPreprocessor.") + with tf.variable_scope(name): + policy_map[name] = cls(obs_space, act_space, merged_conf) + return policy_map, preprocessors + def _validate_and_canonicalize(policy_graph, env): if isinstance(policy_graph, dict): @@ -503,6 +562,11 @@ def _validate_and_canonicalize(policy_graph, env): elif not issubclass(policy_graph, PolicyGraph): raise ValueError("policy_graph must be a rllib.PolicyGraph class") else: + if (isinstance(env, MultiAgentEnv) + and not hasattr(env, "observation_space")): + raise ValueError( + "MultiAgentEnv must have observation_space defined if run " + "in a single-agent configuration.") return { DEFAULT_POLICY_ID: (policy_graph, env.observation_space, env.action_space, {}) diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index 925fa70aa1545..c19da286b0b9a 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -40,14 +40,16 @@ class you pass into PolicyEvaluator will be constructed with def compute_actions(self, obs_batch, state_batches, - is_training=False, + prev_action_batch=None, + prev_reward_batch=None, episodes=None): """Compute actions for the current policy. Arguments: obs_batch (np.ndarray): batch of observations state_batches (list): list of RNN state input batches, if any - is_training (bool): whether we are training the policy + prev_action_batch (np.ndarray): batch of previous action values + prev_reward_batch (np.ndarray): batch of previous rewards episodes (list): MultiAgentEpisode for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multiagent algorithms. @@ -65,17 +67,19 @@ def compute_actions(self, def compute_single_action(self, obs, state, - is_training=False, + prev_action_batch=None, + prev_reward_batch=None, episode=None): """Unbatched version of compute_actions. 
Arguments: obs (obj): single observation state_batches (list): list of RNN state inputs, if any - is_training (bool): whether we are training the policy + prev_action_batch (np.ndarray): batch of previous action values + prev_reward_batch (np.ndarray): batch of previous rewards episode (MultiAgentEpisode): this provides access to all of the internal episode state, which may be useful for model-based or - multiagent algorithms. + multi-agent algorithms. Returns: actions (obj): single action @@ -84,11 +88,14 @@ def compute_single_action(self, """ [action], state_out, info = self.compute_actions( - [obs], [[s] for s in state], is_training, episodes=[episode]) + [obs], [[s] for s in state], episodes=[episode]) return action, [s[0] for s in state_out], \ {k: v[0] for k, v in info.items()} - def postprocess_trajectory(self, sample_batch, other_agent_batches=None): + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): """Implements algorithm-specific trajectory postprocessing. This will be called on each trajectory fragment computed during policy @@ -100,6 +107,9 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None): other_agent_batches (dict): In a multi-agent env, this contains a mapping of agent ids to (policy_graph, agent_batch) tuples containing the policy graph and experiences of the other agent. + episode (MultiAgentEpisode): this provides access to all of the + internal episode state, which may be useful for model-based or + multi-agent algorithms. Returns: SampleBatch: postprocessed sample batch. diff --git a/python/ray/rllib/evaluation/sample_batch.py b/python/ray/rllib/evaluation/sample_batch.py index f8f88a4aae314..5a0099530705f 100644 --- a/python/ray/rllib/evaluation/sample_batch.py +++ b/python/ray/rllib/evaluation/sample_batch.py @@ -79,6 +79,11 @@ def __init__(self, policy_map, clip_rewards): self.agent_to_policy = {} self.count = 0 # increment this manually + def total(self): + """Returns summed number of steps across all agent buffers.""" + + return sum(p.count for p in self.policy_builders.values()) + def has_pending_data(self): """Returns whether there is pending unprocessed data.""" @@ -99,11 +104,14 @@ def add_values(self, agent_id, policy_id, **values): builder = self.agent_builders[agent_id] builder.add_values(**values) - def postprocess_batch_so_far(self): + def postprocess_batch_so_far(self, episode): """Apply policy postprocessors to any unprocessed rows. This pushes the postprocessed per-agent batches onto the per-policy builders, clearing per-agent state. + + Arguments: + episode: current MultiAgentEpisode object or None """ # Materialize the batches so far @@ -128,7 +136,7 @@ def postprocess_batch_so_far(self): "Batches sent to postprocessing must only contain steps " "from a single trajectory.", pre_batch) post_batches[agent_id] = policy.postprocess_trajectory( - pre_batch, other_batches) + pre_batch, other_batches, episode) # Append into policy batches and reset for agent_id, post_batch in sorted(post_batches.items()): @@ -137,14 +145,17 @@ def postprocess_batch_so_far(self): self.agent_builders.clear() self.agent_to_policy.clear() - def build_and_reset(self): + def build_and_reset(self, episode): """Returns the accumulated sample batches for each policy. Any unprocessed rows will be first postprocessed with a policy postprocessor. The internal state of this builder will be reset. 
+ + Arguments: + episode: current MultiAgentEpisode object or None """ - self.postprocess_batch_so_far() + self.postprocess_batch_so_far(episode) policy_batches = {} for policy_id, builder in self.policy_builders.items(): if builder.count > 0: @@ -189,6 +200,11 @@ def concat_samples(samples): out[policy_id] = SampleBatch.concat_samples(batches) return MultiAgentBatch(out, total_count) + def copy(self): + return MultiAgentBatch( + {k: v.copy() + for (k, v) in self.policy_batches.items()}, self.count) + def total(self): ct = 0 for batch in self.policy_batches.values(): @@ -250,6 +266,11 @@ def concat(self, other): out[k] = np.concatenate([self[k], other[k]]) return SampleBatch(out) + def copy(self): + return SampleBatch( + {k: np.array(v, copy=True) + for (k, v) in self.data.items()}) + def rows(self): """Returns an iterator over data rows, i.e. dicts with column values. diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index f41c3ca739e21..ac7c6ed8a7ea8 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -2,23 +2,32 @@ from __future__ import division from __future__ import print_function +import gym from collections import defaultdict, namedtuple +import logging +import numpy as np import six.moves.queue as queue import threading -from ray.rllib.evaluation.episode import MultiAgentEpisode +from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action from ray.rllib.evaluation.sample_batch import MultiAgentSampleBatchBuilder, \ MultiAgentBatch from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph from ray.rllib.env.async_vector_env import AsyncVectorEnv from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv +from ray.rllib.models.action_dist import TupleActions from ray.rllib.utils.tf_run_builder import TFRunBuilder +logger = logging.getLogger(__name__) +_large_batch_warned = False + RolloutMetrics = namedtuple( - "RolloutMetrics", ["episode_length", "episode_reward", "agent_rewards"]) + "RolloutMetrics", + ["episode_length", "episode_reward", "agent_rewards", "custom_metrics"]) -PolicyEvalData = namedtuple("PolicyEvalData", - ["env_id", "agent_id", "obs", "rnn_state"]) +PolicyEvalData = namedtuple( + "PolicyEvalData", + ["env_id", "agent_id", "obs", "rnn_state", "prev_action", "prev_reward"]) class SyncSampler(object): @@ -34,23 +43,28 @@ def __init__(self, env, policies, policy_mapping_fn, + preprocessors, obs_filters, clip_rewards, - num_local_steps, + unroll_length, + callbacks, horizon=None, pack=False, - tf_sess=None): + tf_sess=None, + clip_actions=True): self.async_vector_env = AsyncVectorEnv.wrap_async(env) - self.num_local_steps = num_local_steps + self.unroll_length = unroll_length self.horizon = horizon self.policies = policies self.policy_mapping_fn = policy_mapping_fn - self._obs_filters = obs_filters + self.preprocessors = preprocessors + self.obs_filters = obs_filters self.extra_batches = queue.Queue() self.rollout_provider = _env_runner( self.async_vector_env, self.extra_batches.put, self.policies, - self.policy_mapping_fn, self.num_local_steps, self.horizon, - self._obs_filters, clip_rewards, pack, tf_sess) + self.policy_mapping_fn, self.unroll_length, self.horizon, + self.preprocessors, self.obs_filters, clip_rewards, clip_actions, + pack, callbacks, tf_sess) self.metrics_queue = queue.Queue() def get_data(self): @@ -90,12 +104,15 @@ def __init__(self, env, policies, policy_mapping_fn, + preprocessors, obs_filters, clip_rewards, - 
num_local_steps, + unroll_length, + callbacks, horizon=None, pack=False, - tf_sess=None): + tf_sess=None, + clip_actions=True): for _, f in obs_filters.items(): assert getattr(f, "is_concurrent", False), \ "Observation Filter must support concurrent updates." @@ -104,15 +121,18 @@ def __init__(self, self.queue = queue.Queue(5) self.extra_batches = queue.Queue() self.metrics_queue = queue.Queue() - self.num_local_steps = num_local_steps + self.unroll_length = unroll_length self.horizon = horizon self.policies = policies self.policy_mapping_fn = policy_mapping_fn - self._obs_filters = obs_filters + self.preprocessors = preprocessors + self.obs_filters = obs_filters self.clip_rewards = clip_rewards self.daemon = True self.pack = pack self.tf_sess = tf_sess + self.callbacks = callbacks + self.clip_actions = clip_actions def run(self): try: @@ -124,8 +144,9 @@ def run(self): def _run(self): rollout_provider = _env_runner( self.async_vector_env, self.extra_batches.put, self.policies, - self.policy_mapping_fn, self.num_local_steps, self.horizon, - self._obs_filters, self.clip_rewards, self.pack, self.tf_sess) + self.policy_mapping_fn, self.unroll_length, self.horizon, + self.preprocessors, self.obs_filters, self.clip_rewards, + self.clip_actions, self.pack, self.callbacks, self.tf_sess) while True: # The timeout variable exists because apparently, if one worker # dies, the other workers won't die with it, unless the timeout is @@ -182,11 +203,14 @@ def _env_runner(async_vector_env, extra_batch_callback, policies, policy_mapping_fn, - num_local_steps, + unroll_length, horizon, + preprocessors, obs_filters, clip_rewards, + clip_actions, pack, + callbacks, tf_sess=None): """This implements the common experience collection logic. @@ -197,14 +221,18 @@ def _env_runner(async_vector_env, policy_mapping_fn (func): Function that maps agent ids to policy ids. This is called when an agent first enters the environment. The agent is then "bound" to the returned policy for the episode. - num_local_steps (int): Number of episode steps before `SampleBatch` is + unroll_length (int): Number of episode steps before `SampleBatch` is yielded. Set to infinity to yield complete episodes. horizon (int): Horizon of the episode. + preprocessors (dict): Map of policy id to preprocessor for the + observations prior to filtering. obs_filters (dict): Map of policy id to filter used to process observations for the policy. clip_rewards (bool): Whether to clip rewards before postprocessing. pack (bool): Whether to pack multiple episodes into each batch. This - guarantees batches will be exactly `num_local_steps` in size. + guarantees batches will be exactly `unroll_length` in size. + clip_actions (bool): Whether to clip actions to the space range. + callbacks (dict): User callbacks to run on episode events. tf_sess (Session|None): Optional tensorflow session to use for batching TF policy evaluations. 
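PolicyEvalData now carries prev_action and prev_reward, and _do_policy_eval (further below) forwards them to compute_actions() as prev_action_batch / prev_reward_batch. A minimal sketch of a policy graph written against the updated policy_graph.py signature; the random action choice and no-op weight methods are placeholders, and the constructor simply mirrors how _build_policy_map instantiates graphs with (obs_space, act_space, config).

from ray.rllib.evaluation.policy_graph import PolicyGraph


class RandomPolicyGraph(PolicyGraph):
    """Illustrative only: ignores observations and samples random actions."""

    def __init__(self, observation_space, action_space, config):
        PolicyGraph.__init__(self, observation_space, action_space, config)
        # Stored explicitly so compute_actions below does not depend on
        # base-class attribute handling.
        self.action_space = action_space

    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        episodes=None):
        # prev_action_batch / prev_reward_batch are zero-filled at t=0 by
        # _process_observations, so they are always safe to consume.
        actions = [self.action_space.sample() for _ in obs_batch]
        return actions, [], {}

    def get_weights(self):
        return {}

    def set_weights(self, weights):
        pass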
@@ -218,7 +246,7 @@ def _env_runner(async_vector_env, horizon = ( async_vector_env.get_unwrapped()[0].spec.max_episode_steps) except Exception: - print("Warning, no horizon specified, assuming infinite") + logger.warn("no episode horizon specified, assuming inf") if not horizon: horizon = float("inf") @@ -233,8 +261,14 @@ def get_batch_builder(): return MultiAgentSampleBatchBuilder(policies, clip_rewards) def new_episode(): - return MultiAgentEpisode(policies, policy_mapping_fn, - get_batch_builder, extra_batch_callback) + episode = MultiAgentEpisode(policies, policy_mapping_fn, + get_batch_builder, extra_batch_callback) + if callbacks.get("on_episode_start"): + callbacks["on_episode_start"]({ + "env": async_vector_env, + "episode": episode + }) + return episode active_episodes = defaultdict(new_episode) @@ -243,152 +277,267 @@ def new_episode(): unfiltered_obs, rewards, dones, infos, off_policy_actions = \ async_vector_env.poll() - # Map of policy_id to list of PolicyEvalData - to_eval = defaultdict(list) + # Process observations and prepare for policy evaluation + active_envs, to_eval, outputs = _process_observations( + async_vector_env, policies, batch_builder_pool, active_episodes, + unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon, + preprocessors, obs_filters, unroll_length, pack, callbacks) + for o in outputs: + yield o - # Map of env_id -> agent_id -> action replies - actions_to_send = defaultdict(dict) + # Do batched policy eval + eval_results = _do_policy_eval(tf_sess, to_eval, policies, + active_episodes, clip_actions) - # For each environment - for env_id, agent_obs in unfiltered_obs.items(): - new_episode = env_id not in active_episodes - episode = active_episodes[env_id] - if not new_episode: - episode.length += 1 - episode.batch_builder.count += 1 - episode._add_agent_rewards(rewards[env_id]) - - # Check episode termination conditions - if dones[env_id]["__all__"] or episode.length >= horizon: - all_done = True - atari_metrics = _fetch_atari_metrics(async_vector_env) - if atari_metrics is not None: - for m in atari_metrics: - yield m - else: - yield RolloutMetrics(episode.length, episode.total_reward, - dict(episode.agent_rewards)) + # Process results and update episode state + actions_to_send = _process_policy_eval_results( + to_eval, eval_results, active_episodes, active_envs, + off_policy_actions) + + # Return computed actions to ready envs. We also send to envs that have + # taken off-policy actions; those envs are free to ignore the action. + async_vector_env.send_actions(actions_to_send) + + +def _process_observations(async_vector_env, policies, batch_builder_pool, + active_episodes, unfiltered_obs, rewards, dones, + infos, off_policy_actions, horizon, preprocessors, + obs_filters, unroll_length, pack, callbacks): + """Record new data from the environment and prepare for policy evaluation. 
+ + Returns: + active_envs: set of non-terminated env ids + to_eval: map of policy_id to list of agent PolicyEvalData + outputs: list of metrics and samples to return from the sampler + """ + + active_envs = set() + to_eval = defaultdict(list) + outputs = [] + + # For each environment + for env_id, agent_obs in unfiltered_obs.items(): + new_episode = env_id not in active_episodes + episode = active_episodes[env_id] + if not new_episode: + episode.length += 1 + episode.batch_builder.count += 1 + episode._add_agent_rewards(rewards[env_id]) + + global _large_batch_warned + if (not _large_batch_warned and + episode.batch_builder.total() > max(1000, unroll_length * 10)): + _large_batch_warned = True + logger.warn( + "More than {} observations for {} env steps ".format( + episode.batch_builder.total(), + episode.batch_builder.count) + "are buffered in " + "the sampler. If this is not intentional, check that the " + "the `horizon` config is set correctly, or consider setting " + "`batch_mode` to 'truncate_episodes'. Note that in " + "multi-agent environments, `sample_batch_size` sets the " + "batch size based on environment steps, not the steps of " + "individual agents.") + + # Check episode termination conditions + if dones[env_id]["__all__"] or episode.length >= horizon: + all_done = True + atari_metrics = _fetch_atari_metrics(async_vector_env) + if atari_metrics is not None: + for m in atari_metrics: + outputs.append( + m._replace(custom_metrics=episode.custom_metrics)) + else: + outputs.append( + RolloutMetrics(episode.length, episode.total_reward, + dict(episode.agent_rewards), + episode.custom_metrics)) + else: + all_done = False + active_envs.add(env_id) + + # For each agent in the environment + for agent_id, raw_obs in agent_obs.items(): + policy_id = episode.policy_for(agent_id) + prep_obs = _get_or_raise(preprocessors, + policy_id).transform(raw_obs) + filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs) + agent_done = bool(all_done or dones[env_id].get(agent_id)) + if not agent_done: + to_eval[policy_id].append( + PolicyEvalData(env_id, agent_id, filtered_obs, + episode.rnn_state_for(agent_id), + episode.last_action_for(agent_id), + rewards[env_id][agent_id] or 0.0)) + + last_observation = episode.last_observation_for(agent_id) + episode._set_last_observation(agent_id, filtered_obs) + episode._set_last_info(agent_id, infos[env_id][agent_id]) + + # Record transition info if applicable + if last_observation is not None and \ + infos[env_id][agent_id].get("training_enabled", True): + episode.batch_builder.add_values( + agent_id, + policy_id, + t=episode.length - 1, + eps_id=episode.episode_id, + agent_index=episode._agent_index(agent_id), + obs=last_observation, + actions=episode.last_action_for(agent_id), + rewards=rewards[env_id][agent_id], + prev_actions=episode.prev_action_for(agent_id), + prev_rewards=episode.prev_reward_for(agent_id), + dones=agent_done, + infos=infos[env_id][agent_id], + new_obs=filtered_obs, + **episode.last_pi_info_for(agent_id)) + + # Invoke the step callback after the step is logged to the episode + if callbacks.get("on_episode_step"): + callbacks["on_episode_step"]({ + "env": async_vector_env, + "episode": episode + }) + + # Cut the batch if we're not packing multiple episodes into one, + # or if we've exceeded the requested batch size. 
+ if episode.batch_builder.has_pending_data(): + if (all_done and not pack) or \ + episode.batch_builder.count >= unroll_length: + outputs.append(episode.batch_builder.build_and_reset(episode)) + elif all_done: + # Make sure postprocessor stays within one episode + episode.batch_builder.postprocess_batch_so_far(episode) + + if all_done: + # Handle episode termination + batch_builder_pool.append(episode.batch_builder) + if callbacks.get("on_episode_end"): + callbacks["on_episode_end"]({ + "env": async_vector_env, + "episode": episode + }) + del active_episodes[env_id] + resetted_obs = async_vector_env.try_reset(env_id) + if resetted_obs is None: + # Reset not supported, drop this env from the ready list + if horizon != float("inf"): + raise ValueError( + "Setting episode horizon requires reset() support " + "from the environment.") else: - all_done = False - # At least send an empty dict if not done - actions_to_send[env_id] = {} - - # For each agent in the environment - for agent_id, raw_obs in agent_obs.items(): - policy_id = episode.policy_for(agent_id) - filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs) - agent_done = bool(all_done or dones[env_id].get(agent_id)) - if not agent_done: + # Creates a new episode + episode = active_episodes[env_id] + for agent_id, raw_obs in resetted_obs.items(): + policy_id = episode.policy_for(agent_id) + policy = _get_or_raise(policies, policy_id) + prep_obs = _get_or_raise(preprocessors, + policy_id).transform(raw_obs) + filtered_obs = _get_or_raise(obs_filters, + policy_id)(prep_obs) + episode._set_last_observation(agent_id, filtered_obs) to_eval[policy_id].append( - PolicyEvalData(env_id, agent_id, filtered_obs, - episode.rnn_state_for(agent_id))) - - last_observation = episode.last_observation_for(agent_id) - episode._set_last_observation(agent_id, filtered_obs) - - # Record transition info if applicable - if last_observation is not None and \ - infos[env_id][agent_id].get("training_enabled", True): - episode.batch_builder.add_values( - agent_id, - policy_id, - t=episode.length - 1, - eps_id=episode.episode_id, - obs=last_observation, - actions=episode.last_action_for(agent_id), - rewards=rewards[env_id][agent_id], - dones=agent_done, - infos=infos[env_id][agent_id], - new_obs=filtered_obs, - **episode.last_pi_info_for(agent_id)) - - # Cut the batch if we're not packing multiple episodes into one, - # or if we've exceeded the requested batch size. - if episode.batch_builder.has_pending_data(): - if (all_done and not pack) or \ - episode.batch_builder.count >= num_local_steps: - yield episode.batch_builder.build_and_reset() - elif all_done: - # Make sure postprocessor stays within one episode - episode.batch_builder.postprocess_batch_so_far() - - if all_done: - # Handle episode termination - batch_builder_pool.append(episode.batch_builder) - del active_episodes[env_id] - resetted_obs = async_vector_env.try_reset(env_id) - if resetted_obs is None: - # Reset not supported, drop this env from the ready list - assert horizon == float("inf"), \ - "Setting episode horizon requires reset() support." 
- else: - # Creates a new episode - episode = active_episodes[env_id] - for agent_id, raw_obs in resetted_obs.items(): - policy_id = episode.policy_for(agent_id) - filtered_obs = _get_or_raise(obs_filters, - policy_id)(raw_obs) - episode._set_last_observation(agent_id, filtered_obs) - to_eval[policy_id].append( - PolicyEvalData(env_id, agent_id, filtered_obs, - episode.rnn_state_for(agent_id))) - - # Batch eval policy actions if possible - if tf_sess: - builder = TFRunBuilder(tf_sess, "policy_eval") - pending_fetches = {} + PolicyEvalData( + env_id, agent_id, filtered_obs, + episode.rnn_state_for(agent_id), + np.zeros_like( + _flatten_action(policy.action_space.sample())), + 0.0)) + + return active_envs, to_eval, outputs + + +def _do_policy_eval(tf_sess, to_eval, policies, active_episodes, clip_actions): + """Call compute actions on observation batches to get next actions. + + Returns: + eval_results: dict of policy to compute_action() outputs. + """ + + eval_results = {} + + if tf_sess: + builder = TFRunBuilder(tf_sess, "policy_eval") + pending_fetches = {} + else: + builder = None + for policy_id, eval_data in to_eval.items(): + rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data]) + policy = _get_or_raise(policies, policy_id) + if builder and (policy.compute_actions.__code__ is + TFPolicyGraph.compute_actions.__code__): + pending_fetches[policy_id] = policy._build_compute_actions( + builder, [t.obs for t in eval_data], + rnn_in_cols, + prev_action_batch=[t.prev_action for t in eval_data], + prev_reward_batch=[t.prev_reward for t in eval_data]) else: - builder = None - eval_results = {} - rnn_in_cols = {} - for policy_id, eval_data in to_eval.items(): - rnn_in = _to_column_format([t.rnn_state for t in eval_data]) - rnn_in_cols[policy_id] = rnn_in + eval_results[policy_id] = policy.compute_actions( + [t.obs for t in eval_data], + rnn_in_cols, + prev_action_batch=[t.prev_action for t in eval_data], + prev_reward_batch=[t.prev_reward for t in eval_data], + episodes=[active_episodes[t.env_id] for t in eval_data]) + if builder: + for k, v in pending_fetches.items(): + eval_results[k] = builder.get(v) + + if clip_actions: + for policy_id, results in eval_results.items(): policy = _get_or_raise(policies, policy_id) - if builder and (policy.compute_actions.__code__ is - TFPolicyGraph.compute_actions.__code__): - pending_fetches[policy_id] = policy.build_compute_actions( - builder, [t.obs for t in eval_data], - rnn_in, - is_training=True) + actions, rnn_out_cols, pi_info_cols = results + eval_results[policy_id] = (_clip_actions( + actions, policy.action_space), rnn_out_cols, pi_info_cols) + + return eval_results + + +def _process_policy_eval_results(to_eval, eval_results, active_episodes, + active_envs, off_policy_actions): + """Process the output of policy neural network evaluation. + + Records policy evaluation results into the given episode objects and + returns replies to send back to agents in the env. + + Returns: + actions_to_send: nested dict of env id -> agent id -> agent replies. 
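+            Envs with no pending agents still map to an empty dict so they
+            remain marked as ready.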
+ """ + + actions_to_send = defaultdict(dict) + for env_id in active_envs: + actions_to_send[env_id] = {} # at minimum send empty dict + + for policy_id, eval_data in to_eval.items(): + rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data]) + actions, rnn_out_cols, pi_info_cols = eval_results[policy_id] + if len(rnn_in_cols) != len(rnn_out_cols): + raise ValueError("Length of RNN in did not match RNN out, got: " + "{} vs {}".format(rnn_in_cols, rnn_out_cols)) + # Add RNN state info + for f_i, column in enumerate(rnn_in_cols): + pi_info_cols["state_in_{}".format(f_i)] = column + for f_i, column in enumerate(rnn_out_cols): + pi_info_cols["state_out_{}".format(f_i)] = column + # Save output rows + actions = _unbatch_tuple_actions(actions) + for i, action in enumerate(actions): + env_id = eval_data[i].env_id + agent_id = eval_data[i].agent_id + actions_to_send[env_id][agent_id] = action + episode = active_episodes[env_id] + episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) + episode._set_last_pi_info( + agent_id, {k: v[i] + for k, v in pi_info_cols.items()}) + if env_id in off_policy_actions and \ + agent_id in off_policy_actions[env_id]: + episode._set_last_action(agent_id, + off_policy_actions[env_id][agent_id]) else: - eval_results[policy_id] = policy.compute_actions( - [t.obs for t in eval_data], - rnn_in, - is_training=True, - episodes=[active_episodes[t.env_id] for t in eval_data]) - if builder: - for k, v in pending_fetches.items(): - eval_results[k] = builder.get(v) - - # Record the policy eval results - for policy_id, eval_data in to_eval.items(): - actions, rnn_out_cols, pi_info_cols = eval_results[policy_id] - # Add RNN state info - for f_i, column in enumerate(rnn_in_cols[policy_id]): - pi_info_cols["state_in_{}".format(f_i)] = column - for f_i, column in enumerate(rnn_out_cols): - pi_info_cols["state_out_{}".format(f_i)] = column - # Save output rows - for i, action in enumerate(actions): - env_id = eval_data[i].env_id - agent_id = eval_data[i].agent_id - actions_to_send[env_id][agent_id] = action - episode = active_episodes[env_id] - episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) - episode._set_last_pi_info( - agent_id, {k: v[i] - for k, v in pi_info_cols.items()}) - if env_id in off_policy_actions and \ - agent_id in off_policy_actions[env_id]: - episode._set_last_action( - agent_id, off_policy_actions[env_id][agent_id]) - else: - episode._set_last_action(agent_id, action) + episode._set_last_action(agent_id, action) - # Return computed actions to ready envs. We also send to envs that have - # taken off-policy actions; those envs are free to ignore the action. - async_vector_env.send_actions(dict(actions_to_send)) + return actions_to_send def _fetch_atari_metrics(async_vector_env): @@ -405,10 +554,48 @@ def _fetch_atari_metrics(async_vector_env): if not monitor: return None for eps_rew, eps_len in monitor.next_episode_results(): - atari_out.append(RolloutMetrics(eps_len, eps_rew, {})) + atari_out.append(RolloutMetrics(eps_len, eps_rew, {}, {})) return atari_out +def _clip_actions(actions, space): + """Called to clip actions to the specified range of this policy. + + Arguments: + actions: Batch of actions or TupleActions. + space: Action space the actions should be present in. + + Returns: + Clipped batch of actions. 
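+
+    For example (illustrative values), a Box(-1.0, 1.0, shape=(1, )) space
+    clips the batch [[2.0], [-0.3]] elementwise to [[1.0], [-0.3]]. Tuple
+    actions are clipped per sub-space; other spaces pass through unchanged.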
+ """ + + if isinstance(space, gym.spaces.Box): + return np.clip(actions, space.low, space.high) + elif isinstance(space, gym.spaces.Tuple): + if not isinstance(actions, TupleActions): + raise ValueError("Expected tuple space for actions {}: {}".format( + actions, space)) + out = [] + for a, s in zip(actions.batches, space.spaces): + out.append(_clip_actions(a, s)) + return TupleActions(out) + else: + return actions + + +def _unbatch_tuple_actions(action_batch): + # convert list of batches -> batch of lists + if isinstance(action_batch, TupleActions): + out = [] + for j in range(len(action_batch.batches[0])): + out.append([ + action_batch.batches[i][j] + for i in range(len(action_batch.batches)) + ]) + return out + return action_batch + + def _to_column_format(rnn_state_rows): num_cols = len(rnn_state_rows[0]) return [[row[i] for row in rnn_state_rows] for i in range(num_cols)] diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index e9119c87527b4..e5a1d7b19732b 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -2,14 +2,18 @@ from __future__ import division from __future__ import print_function +import logging import tensorflow as tf import numpy as np import ray from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.models.lstm import chop_into_sequences -from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule +from ray.rllib.utils.tf_run_builder import TFRunBuilder + +logger = logging.getLogger(__name__) class TFPolicyGraph(PolicyGraph): @@ -27,7 +31,7 @@ class TFPolicyGraph(PolicyGraph): Examples: >>> policy = TFPolicyGraphSubclass( - sess, obs_input, action_sampler, loss, loss_inputs, is_training) + sess, obs_input, action_sampler, loss, loss_inputs) >>> print(policy.compute_actions([1, 0, 2])) (array([0, 1, 1]), [], {}) @@ -46,8 +50,12 @@ def __init__(self, loss_inputs, state_inputs=None, state_outputs=None, + prev_action_input=None, + prev_reward_input=None, seq_lens=None, - max_seq_len=20): + max_seq_len=20, + batch_divisibility_req=1, + update_ops=None): """Initialize the policy graph. Arguments: @@ -62,174 +70,261 @@ def __init__(self, loss_inputs (list): a (name, placeholder) tuple for each loss input argument. Each placeholder name must correspond to a SampleBatch column key returned by postprocess_trajectory(), - and has shape [BATCH_SIZE, data...]. + and has shape [BATCH_SIZE, data...]. These keys will be read + from postprocessed sample batches and fed into the specified + placeholders during loss computation. state_inputs (list): list of RNN state input Tensors. state_outputs (list): list of RNN state output Tensors. + prev_action_input (Tensor): placeholder for previous actions + prev_reward_input (Tensor): placeholder for previous rewards seq_lens (Tensor): placeholder for RNN sequence lengths, of shape [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See models/lstm.py for more information. max_seq_len (int): max sequence length for LSTM training. + batch_divisibility_req (int): pad all agent experiences batches to + multiples of this value. This only has an effect if not using + a LSTM model. + update_ops (list): override the batchnorm update ops to run when + applying gradients. Otherwise we run all update ops found in + the current variable scope. 
""" self.observation_space = observation_space self.action_space = action_space self._sess = sess self._obs_input = obs_input + self._prev_action_input = prev_action_input + self._prev_reward_input = prev_reward_input self._sampler = action_sampler self._loss = loss self._loss_inputs = loss_inputs self._loss_input_dict = dict(self._loss_inputs) - self._is_training = tf.placeholder_with_default(True, ()) + self._is_training = self._get_is_training_placeholder() self._state_inputs = state_inputs or [] self._state_outputs = state_outputs or [] for i, ph in enumerate(self._state_inputs): self._loss_input_dict["state_in_{}".format(i)] = ph self._seq_lens = seq_lens self._max_seq_len = max_seq_len + self._batch_divisibility_req = batch_divisibility_req + self._optimizer = self.optimizer() self._grads_and_vars = [(g, v) for (g, v) in self.gradients(self._optimizer) if g is not None] self._grads = [g for (g, v) in self._grads_and_vars] - self._apply_op = self._optimizer.apply_gradients(self._grads_and_vars) self._variables = ray.experimental.TensorFlowVariables( self._loss, self._sess) - assert len(self._state_inputs) == len(self._state_outputs) == \ - len(self.get_initial_state()), \ - (self._state_inputs, self._state_outputs, self.get_initial_state()) - if self._state_inputs: - assert self._seq_lens is not None - - def build_compute_actions(self, - builder, - obs_batch, - state_batches=None, - is_training=False, - episodes=None): - state_batches = state_batches or [] - assert len(self._state_inputs) == len(state_batches), \ - (self._state_inputs, state_batches) - builder.add_feed_dict(self.extra_compute_action_feed_dict()) - builder.add_feed_dict({self._obs_input: obs_batch}) - if state_batches: - builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) - builder.add_feed_dict({self._is_training: is_training}) - builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) - fetches = builder.add_fetches([self._sampler] + self._state_outputs + - [self.extra_compute_action_fetches()]) - return fetches[0], fetches[1:-1], fetches[-1] - + # gather update ops for any batch norm layers + if update_ops: + self._update_ops = update_ops + else: + self._update_ops = tf.get_collection( + tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) + if self._update_ops: + logger.debug("Update ops to run on apply gradient: {}".format( + self._update_ops)) + with tf.control_dependencies(self._update_ops): + # specify global_step for TD3 which needs to count the num updates + self._apply_op = self._optimizer.apply_gradients( + self._grads_and_vars, + global_step=tf.train.get_or_create_global_step()) + + if len(self._state_inputs) != len(self._state_outputs): + raise ValueError( + "Number of state input and output tensors must match, got: " + "{} vs {}".format(self._state_inputs, self._state_outputs)) + if len(self.get_initial_state()) != len(self._state_inputs): + raise ValueError( + "Length of initial state must match number of state inputs, " + "got: {} vs {}".format(self.get_initial_state(), + self._state_inputs)) + if self._state_inputs and self._seq_lens is None: + raise ValueError( + "seq_lens tensor must be given if state inputs are defined") + + logger.debug("Created {} with loss inputs: {}".format( + self, self._loss_input_dict)) + + @override(PolicyGraph) def compute_actions(self, obs_batch, state_batches=None, - is_training=False, + prev_action_batch=None, + prev_reward_batch=None, episodes=None): builder = TFRunBuilder(self._sess, "compute_actions") - fetches = 
self.build_compute_actions(builder, obs_batch, state_batches, - is_training) + fetches = self._build_compute_actions(builder, obs_batch, + state_batches, prev_action_batch, + prev_reward_batch) return builder.get(fetches) - def _get_loss_inputs_dict(self, batch): - feed_dict = {} - - # Simple case - if not self._state_inputs: - for k, ph in self._loss_inputs: - feed_dict[ph] = batch[k] - return feed_dict - - # RNN case - feature_keys = [k for k, v in self._loss_inputs] - state_keys = [ - "state_in_{}".format(i) for i in range(len(self._state_inputs)) - ] - feature_sequences, initial_states, seq_lens = chop_into_sequences( - batch["eps_id"], [batch[k] for k in feature_keys], - [batch[k] for k in state_keys], self._max_seq_len) - for k, v in zip(feature_keys, feature_sequences): - feed_dict[self._loss_input_dict[k]] = v - for k, v in zip(state_keys, initial_states): - feed_dict[self._loss_input_dict[k]] = v - feed_dict[self._seq_lens] = seq_lens - return feed_dict - - def build_compute_gradients(self, builder, postprocessed_batch): - builder.add_feed_dict(self.extra_compute_grad_feed_dict()) - builder.add_feed_dict({self._is_training: True}) - builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) - fetches = builder.add_fetches( - [self._grads, self.extra_compute_grad_fetches()]) - return fetches[0], fetches[1] - + @override(PolicyGraph) def compute_gradients(self, postprocessed_batch): builder = TFRunBuilder(self._sess, "compute_gradients") - fetches = self.build_compute_gradients(builder, postprocessed_batch) + fetches = self._build_compute_gradients(builder, postprocessed_batch) return builder.get(fetches) - def build_apply_gradients(self, builder, gradients): - assert len(gradients) == len(self._grads), (gradients, self._grads) - builder.add_feed_dict(self.extra_apply_grad_feed_dict()) - builder.add_feed_dict({self._is_training: True}) - builder.add_feed_dict(dict(zip(self._grads, gradients))) - fetches = builder.add_fetches( - [self._apply_op, self.extra_apply_grad_fetches()]) - return fetches[1] - + @override(PolicyGraph) def apply_gradients(self, gradients): builder = TFRunBuilder(self._sess, "apply_gradients") - fetches = self.build_apply_gradients(builder, gradients) + fetches = self._build_apply_gradients(builder, gradients) return builder.get(fetches) - def build_compute_apply(self, builder, postprocessed_batch): - builder.add_feed_dict(self.extra_compute_grad_feed_dict()) - builder.add_feed_dict(self.extra_apply_grad_feed_dict()) - builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) - builder.add_feed_dict({self._is_training: True}) - fetches = builder.add_fetches([ - self._apply_op, - self.extra_compute_grad_fetches(), - self.extra_apply_grad_fetches() - ]) - return fetches[1], fetches[2] - + @override(PolicyGraph) def compute_apply(self, postprocessed_batch): builder = TFRunBuilder(self._sess, "compute_apply") - fetches = self.build_compute_apply(builder, postprocessed_batch) + fetches = self._build_compute_apply(builder, postprocessed_batch) return builder.get(fetches) + @override(PolicyGraph) def get_weights(self): return self._variables.get_flat() + @override(PolicyGraph) def set_weights(self, weights): return self._variables.set_flat(weights) + def copy(self, existing_inputs): + """Creates a copy of self using existing input placeholders. 
+ + Optional, only required to work with the multi-GPU optimizer.""" + raise NotImplementedError + def extra_compute_action_feed_dict(self): + """Extra dict to pass to the compute actions session run.""" return {} def extra_compute_action_fetches(self): + """Extra values to fetch and return from compute_actions().""" return {} # e.g, value function def extra_compute_grad_feed_dict(self): + """Extra dict to pass to the compute gradients session run.""" return {} # e.g, kl_coeff def extra_compute_grad_fetches(self): + """Extra values to fetch and return from compute_gradients().""" return {} # e.g, td error def extra_apply_grad_feed_dict(self): + """Extra dict to pass to the apply gradients session run.""" return {} def extra_apply_grad_fetches(self): + """Extra values to fetch and return from apply_gradients().""" return {} # e.g., batch norm updates def optimizer(self): + """TF optimizer to use for policy optimization.""" return tf.train.AdamOptimizer() def gradients(self, optimizer): + """Override for custom gradient computation.""" return optimizer.compute_gradients(self._loss) - def loss_inputs(self): - return self._loss_inputs + def _build_compute_actions(self, + builder, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + episodes=None): + state_batches = state_batches or [] + assert len(self._state_inputs) == len(state_batches), \ + (self._state_inputs, state_batches) + builder.add_feed_dict(self.extra_compute_action_feed_dict()) + builder.add_feed_dict({self._obs_input: obs_batch}) + if state_batches: + builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) + if self._prev_action_input is not None and prev_action_batch: + builder.add_feed_dict({self._prev_action_input: prev_action_batch}) + if self._prev_reward_input is not None and prev_reward_batch: + builder.add_feed_dict({self._prev_reward_input: prev_reward_batch}) + builder.add_feed_dict({self._is_training: False}) + builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) + fetches = builder.add_fetches([self._sampler] + self._state_outputs + + [self.extra_compute_action_fetches()]) + return fetches[0], fetches[1:-1], fetches[-1] + + def _build_compute_gradients(self, builder, postprocessed_batch): + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict({self._is_training: True}) + builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + fetches = builder.add_fetches( + [self._grads, self.extra_compute_grad_fetches()]) + return fetches[0], fetches[1] + + def _build_apply_gradients(self, builder, gradients): + assert len(gradients) == len(self._grads), (gradients, self._grads) + builder.add_feed_dict(self.extra_apply_grad_feed_dict()) + builder.add_feed_dict({self._is_training: True}) + builder.add_feed_dict(dict(zip(self._grads, gradients))) + fetches = builder.add_fetches( + [self._apply_op, self.extra_apply_grad_fetches()]) + return fetches[1] + + def _build_compute_apply(self, builder, postprocessed_batch): + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict(self.extra_apply_grad_feed_dict()) + builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + builder.add_feed_dict({self._is_training: True}) + fetches = builder.add_fetches([ + self._apply_op, + self.extra_compute_grad_fetches(), + self.extra_apply_grad_fetches() + ]) + return fetches[1], fetches[2] + + def _get_is_training_placeholder(self): + """Get the placeholder for _is_training, i.e., for batch norm layers. 
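+
+        The placeholder defaults to False, so action computation runs in
+        inference mode unless True is fed explicitly (as the gradient ops do).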
+ + This can be called safely before __init__ has run. + """ + if not hasattr(self, "_is_training"): + self._is_training = tf.placeholder_with_default(False, ()) + return self._is_training + + def _get_loss_inputs_dict(self, batch): + feed_dict = {} + if self._batch_divisibility_req > 1: + meets_divisibility_reqs = ( + len(batch["obs"]) % self._batch_divisibility_req == 0 + and max(batch["agent_index"]) == 0) # not multiagent + else: + meets_divisibility_reqs = True + + # Simple case: not RNN nor do we need to pad + if not self._state_inputs and meets_divisibility_reqs: + for k, ph in self._loss_inputs: + feed_dict[ph] = batch[k] + return feed_dict + + if self._state_inputs: + max_seq_len = self._max_seq_len + dynamic_max = True + else: + max_seq_len = self._batch_divisibility_req + dynamic_max = False + + # RNN or multi-agent case + feature_keys = [k for k, v in self._loss_inputs] + state_keys = [ + "state_in_{}".format(i) for i in range(len(self._state_inputs)) + ] + feature_sequences, initial_states, seq_lens = chop_into_sequences( + batch["eps_id"], + batch["agent_index"], [batch[k] for k in feature_keys], + [batch[k] for k in state_keys], + max_seq_len, + dynamic_max=dynamic_max) + for k, v in zip(feature_keys, feature_sequences): + feed_dict[self._loss_input_dict[k]] = v + for k, v in zip(state_keys, initial_states): + feed_dict[self._loss_input_dict[k]] = v + feed_dict[self._seq_lens] = seq_lens + return feed_dict class LearningRateSchedule(object): @@ -243,11 +338,13 @@ def __init__(self, lr, lr_schedule): self.lr_schedule = PiecewiseSchedule( lr_schedule, outside_value=lr_schedule[-1][-1]) + @override(PolicyGraph) def on_global_var_update(self, global_vars): super(LearningRateSchedule, self).on_global_var_update(global_vars) self.cur_lr.load( self.lr_schedule.value(global_vars["timestep"]), session=self._sess) + @override(TFPolicyGraph) def optimizer(self): return tf.train.AdamOptimizer(self.cur_lr) diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index 741357f3aa8d5..c8e86e8451c73 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -13,6 +13,7 @@ pass # soft dep from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.utils.annotations import override class TorchPolicyGraph(PolicyGraph): @@ -56,21 +57,12 @@ def __init__(self, observation_space, action_space, model, loss, self._loss_inputs = loss_inputs self._optimizer = self.optimizer() - def extra_action_out(self, model_out): - """Returns dict of extra info to include in experience batch. 
- - Arguments: - model_out (list): Outputs of the policy model module.""" - return {} - - def optimizer(self): - """Custom PyTorch optimizer to use.""" - return torch.optim.Adam(self._model.parameters()) - + @override(PolicyGraph) def compute_actions(self, obs_batch, state_batches=None, - is_training=False, + prev_action_batch=None, + prev_reward_batch=None, episodes=None): if state_batches: raise NotImplementedError("Torch RNN support") @@ -82,6 +74,7 @@ def compute_actions(self, actions = F.softmax(logits, dim=1).multinomial(1).squeeze(0) return var_to_np(actions), [], self.extra_action_out(model_out) + @override(PolicyGraph) def compute_gradients(self, postprocessed_batch): with self.lock: loss_in = [] @@ -95,6 +88,7 @@ def compute_gradients(self, postprocessed_batch): grads = [var_to_np(p.grad.data) for p in self._model.parameters()] return grads, {} + @override(PolicyGraph) def apply_gradients(self, gradients): with self.lock: for g, p in zip(gradients, self._model.parameters()): @@ -102,10 +96,23 @@ def apply_gradients(self, gradients): self._optimizer.step() return {} + @override(PolicyGraph) def get_weights(self): with self.lock: return self._model.state_dict() + @override(PolicyGraph) def set_weights(self, weights): with self.lock: self._model.load_state_dict(weights) + + def extra_action_out(self, model_out): + """Returns dict of extra info to include in experience batch. + + Arguments: + model_out (list): Outputs of the policy model module.""" + return {} + + def optimizer(self): + """Custom PyTorch optimizer to use.""" + return torch.optim.Adam(self._model.parameters()) diff --git a/python/ray/rllib/examples/batch_norm_model.py b/python/ray/rllib/examples/batch_norm_model.py new file mode 100644 index 0000000000000..abd4b53666a2a --- /dev/null +++ b/python/ray/rllib/examples/batch_norm_model.py @@ -0,0 +1,64 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +"""Example of using a custom model with batch norm.""" + +import argparse + +import tensorflow as tf +import tensorflow.contrib.slim as slim + +import ray +from ray.rllib.models import Model, ModelCatalog +from ray.rllib.models.misc import normc_initializer +from ray.tune import run_experiments + +parser = argparse.ArgumentParser() +parser.add_argument("--num-iters", type=int, default=200) +parser.add_argument("--run", type=str, default="PPO") + + +class BatchNormModel(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + last_layer = input_dict["obs"] + hiddens = [256, 256] + for i, size in enumerate(hiddens): + label = "fc{}".format(i) + last_layer = slim.fully_connected( + last_layer, + size, + weights_initializer=normc_initializer(1.0), + activation_fn=tf.nn.tanh, + scope=label) + # Add a batch norm layer + last_layer = tf.layers.batch_normalization( + last_layer, training=input_dict["is_training"]) + output = slim.fully_connected( + last_layer, + num_outputs, + weights_initializer=normc_initializer(0.01), + activation_fn=None, + scope="fc_out") + return output, last_layer + + +if __name__ == "__main__": + args = parser.parse_args() + ray.init() + + ModelCatalog.register_custom_model("bn_model", BatchNormModel) + run_experiments({ + "batch_norm_demo": { + "run": args.run, + "env": "Pendulum-v0" if args.run == "DDPG" else "CartPole-v0", + "stop": { + "training_iteration": args.num_iters + }, + "config": { + "model": { + "custom_model": "bn_model", + }, + "num_workers": 0, + }, + }, + }) diff --git a/examples/carla/README 
b/python/ray/rllib/examples/carla/README similarity index 100% rename from examples/carla/README rename to python/ray/rllib/examples/carla/README diff --git a/examples/carla/env.py b/python/ray/rllib/examples/carla/env.py similarity index 83% rename from examples/carla/env.py rename to python/ray/rllib/examples/carla/env.py index c88a71b28f51b..af5b619afcdb4 100644 --- a/examples/carla/env.py +++ b/python/ray/rllib/examples/carla/env.py @@ -33,8 +33,8 @@ os.makedirs(CARLA_OUT_PATH) # Set this to the path of your Carla binary -SERVER_BINARY = os.environ.get( - "CARLA_SERVER", os.path.expanduser("~/CARLA_0.7.0/CarlaUE4.sh")) +SERVER_BINARY = os.environ.get("CARLA_SERVER", + os.path.expanduser("~/CARLA_0.7.0/CarlaUE4.sh")) assert os.path.exists(SERVER_BINARY) if "CARLA_PY_PATH" in os.environ: @@ -97,7 +97,6 @@ "squash_action_logits": False, } - DISCRETE_ACTIONS = { # coast 0: [0.0, 0.0], @@ -119,7 +118,6 @@ 8: [-0.5, 0.5], } - live_carla_processes = set() @@ -133,7 +131,6 @@ def cleanup(): class CarlaEnv(gym.Env): - def __init__(self, config=ENV_CONFIG): self.config = config self.city = self.config["server_map"].split("/")[-1] @@ -143,21 +140,27 @@ def __init__(self, config=ENV_CONFIG): if config["discrete_actions"]: self.action_space = Discrete(len(DISCRETE_ACTIONS)) else: - self.action_space = Box(-1.0, 1.0, shape=(2,), dtype=np.float32) + self.action_space = Box(-1.0, 1.0, shape=(2, ), dtype=np.float32) if config["use_depth_camera"]: image_space = Box( - -1.0, 1.0, shape=( - config["y_res"], config["x_res"], - 1 * config["framestack"]), dtype=np.float32) + -1.0, + 1.0, + shape=(config["y_res"], config["x_res"], + 1 * config["framestack"]), + dtype=np.float32) else: image_space = Box( - 0, 255, shape=( - config["y_res"], config["x_res"], - 3 * config["framestack"]), dtype=np.uint8) + 0, + 255, + shape=(config["y_res"], config["x_res"], + 3 * config["framestack"]), + dtype=np.uint8) self.observation_space = Tuple( # forward_speed, dist to goal - [image_space, - Discrete(len(COMMANDS_ENUM)), # next_command - Box(-128.0, 128.0, shape=(2,), dtype=np.float32)]) + [ + image_space, + Discrete(len(COMMANDS_ENUM)), # next_command + Box(-128.0, 128.0, shape=(2, ), dtype=np.float32) + ]) # TODO(ekl) this isn't really a proper gym spec self._spec = lambda: None @@ -185,11 +188,13 @@ def init_server(self): # Create a new server process and start the client. 
self.server_port = random.randint(10000, 60000) self.server_process = subprocess.Popen( - [SERVER_BINARY, self.config["server_map"], - "-windowed", "-ResX=400", "-ResY=300", - "-carla-server", - "-carla-world-port={}".format(self.server_port)], - preexec_fn=os.setsid, stdout=open(os.devnull, "w")) + [ + SERVER_BINARY, self.config["server_map"], "-windowed", + "-ResX=400", "-ResY=300", "-carla-server", + "-carla-world-port={}".format(self.server_port) + ], + preexec_fn=os.setsid, + stdout=open(os.devnull, "w")) live_carla_processes.add(os.getpgid(self.server_process.pid)) for i in range(RETRIES_ON_ERROR): @@ -257,14 +262,14 @@ def _reset(self): if self.config["use_depth_camera"]: camera1 = Camera("CameraDepth", PostProcessing="Depth") - camera1.set_image_size( - self.config["render_x_res"], self.config["render_y_res"]) + camera1.set_image_size(self.config["render_x_res"], + self.config["render_y_res"]) camera1.set_position(30, 0, 130) settings.add_sensor(camera1) camera2 = Camera("CameraRGB") - camera2.set_image_size( - self.config["render_x_res"], self.config["render_y_res"]) + camera2.set_image_size(self.config["render_x_res"], + self.config["render_y_res"]) camera2.set_position(30, 0, 130) settings.add_sensor(camera2) @@ -274,13 +279,14 @@ def _reset(self): self.start_pos = positions[self.scenario["start_pos_id"]] self.end_pos = positions[self.scenario["end_pos_id"]] self.start_coord = [ - self.start_pos.location.x // 100, self.start_pos.location.y // 100] + self.start_pos.location.x // 100, self.start_pos.location.y // 100 + ] self.end_coord = [ - self.end_pos.location.x // 100, self.end_pos.location.y // 100] - print( - "Start pos {} ({}), end {} ({})".format( - self.scenario["start_pos_id"], self.start_coord, - self.scenario["end_pos_id"], self.end_coord)) + self.end_pos.location.x // 100, self.end_pos.location.y // 100 + ] + print("Start pos {} ({}), end {} ({})".format( + self.scenario["start_pos_id"], self.start_coord, + self.scenario["end_pos_id"], self.end_coord)) # Notify the server that we want to start the episode at the # player_start index. 
This function blocks until the server is ready @@ -300,11 +306,10 @@ def encode_obs(self, image, py_measurements): prev_image = image if self.config["framestack"] == 2: image = np.concatenate([prev_image, image], axis=2) - obs = ( - image, - COMMAND_ORDINAL[py_measurements["next_command"]], - [py_measurements["forward_speed"], - py_measurements["distance_to_goal"]]) + obs = (image, COMMAND_ORDINAL[py_measurements["next_command"]], [ + py_measurements["forward_speed"], + py_measurements["distance_to_goal"] + ]) self.last_obs = obs return obs @@ -313,9 +318,8 @@ def step(self, action): obs = self._step(action) return obs except Exception: - print( - "Error during step, terminating episode early", - traceback.format_exc()) + print("Error during step, terminating episode early", + traceback.format_exc()) self.clear_server_state() return (self.last_obs, 0.0, True, {}) @@ -336,12 +340,14 @@ def _step(self, action): hand_brake = False if self.config["verbose"]: - print( - "steer", steer, "throttle", throttle, "brake", brake, - "reverse", reverse) + print("steer", steer, "throttle", throttle, "brake", brake, + "reverse", reverse) self.client.send_control( - steer=steer, throttle=throttle, brake=brake, hand_brake=hand_brake, + steer=steer, + throttle=throttle, + brake=brake, + hand_brake=hand_brake, reverse=reverse) # Process observations @@ -359,15 +365,14 @@ def _step(self, action): "reverse": reverse, "hand_brake": hand_brake, } - reward = compute_reward( - self, self.prev_measurement, py_measurements) + reward = compute_reward(self, self.prev_measurement, py_measurements) self.total_reward += reward py_measurements["reward"] = reward py_measurements["total_reward"] = self.total_reward - done = (self.num_steps > self.scenario["max_steps"] or - py_measurements["next_command"] == "REACH_GOAL" or - (self.config["early_terminate_on_collision"] and - collided_done(py_measurements))) + done = (self.num_steps > self.scenario["max_steps"] + or py_measurements["next_command"] == "REACH_GOAL" + or (self.config["early_terminate_on_collision"] + and collided_done(py_measurements))) py_measurements["done"] = done self.prev_measurement = py_measurements @@ -377,8 +382,7 @@ def _step(self, action): self.measurements_file = open( os.path.join( CARLA_OUT_PATH, - "measurements_{}.json".format(self.episode_id)), - "w") + "measurements_{}.json".format(self.episode_id)), "w") self.measurements_file.write(json.dumps(py_measurements)) self.measurements_file.write("\n") if done: @@ -389,9 +393,8 @@ def _step(self, action): self.num_steps += 1 image = self.preprocess_image(image) - return ( - self.encode_obs(image, py_measurements), reward, done, - py_measurements) + return (self.encode_obs(image, py_measurements), reward, done, + py_measurements) def images_to_video(self): videos_dir = os.path.join(CARLA_OUT_PATH, "Videos") @@ -413,15 +416,15 @@ def preprocess_image(self, image): if self.config["use_depth_camera"]: assert self.config["use_depth_camera"] data = (image.data - 0.5) * 2 - data = data.reshape( - self.config["render_y_res"], self.config["render_x_res"], 1) + data = data.reshape(self.config["render_y_res"], + self.config["render_x_res"], 1) data = cv2.resize( data, (self.config["x_res"], self.config["y_res"]), interpolation=cv2.INTER_AREA) data = np.expand_dims(data, 2) else: - data = image.data.reshape( - self.config["render_y_res"], self.config["render_x_res"], 3) + data = image.data.reshape(self.config["render_y_res"], + self.config["render_x_res"], 3) data = cv2.resize( data, (self.config["x_res"], 
self.config["y_res"]), interpolation=cv2.INTER_AREA) @@ -448,36 +451,39 @@ def _read_observation(self): cur = measurements.player_measurements if self.config["enable_planner"]: - next_command = COMMANDS_ENUM[ - self.planner.get_next_command( - [cur.transform.location.x, cur.transform.location.y, - GROUND_Z], - [cur.transform.orientation.x, cur.transform.orientation.y, - GROUND_Z], - [self.end_pos.location.x, self.end_pos.location.y, - GROUND_Z], - [self.end_pos.orientation.x, self.end_pos.orientation.y, - GROUND_Z]) - ] + next_command = COMMANDS_ENUM[self.planner.get_next_command( + [cur.transform.location.x, cur.transform.location.y, GROUND_Z], + [ + cur.transform.orientation.x, cur.transform.orientation.y, + GROUND_Z + ], + [self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], [ + self.end_pos.orientation.x, self.end_pos.orientation.y, + GROUND_Z + ])] else: next_command = "LANE_FOLLOW" if next_command == "REACH_GOAL": distance_to_goal = 0.0 # avoids crash in planner elif self.config["enable_planner"]: - distance_to_goal = self.planner.get_shortest_path_distance( - [cur.transform.location.x, cur.transform.location.y, GROUND_Z], - [cur.transform.orientation.x, cur.transform.orientation.y, - GROUND_Z], - [self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], - [self.end_pos.orientation.x, self.end_pos.orientation.y, - GROUND_Z]) / 100 + distance_to_goal = self.planner.get_shortest_path_distance([ + cur.transform.location.x, cur.transform.location.y, GROUND_Z + ], [ + cur.transform.orientation.x, cur.transform.orientation.y, + GROUND_Z + ], [self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], [ + self.end_pos.orientation.x, self.end_pos.orientation.y, + GROUND_Z + ]) / 100 else: distance_to_goal = -1 - distance_to_goal_euclidean = float(np.linalg.norm( - [cur.transform.location.x - self.end_pos.location.x, - cur.transform.location.y - self.end_pos.location.y]) / 100) + distance_to_goal_euclidean = float( + np.linalg.norm([ + cur.transform.location.x - self.end_pos.location.x, + cur.transform.location.y - self.end_pos.location.y + ]) / 100) py_measurements = { "episode_id": self.episode_id, @@ -513,8 +519,8 @@ def _read_observation(self): if not os.path.exists(out_dir): os.makedirs(out_dir) out_file = os.path.join( - out_dir, - "{}_{:>04}.jpg".format(self.episode_id, self.num_steps)) + out_dir, "{}_{:>04}.jpg".format(self.episode_id, + self.num_steps)) scipy.misc.imsave(out_file, image.data) assert observation is not None, sensor_data @@ -621,8 +627,7 @@ def compute_reward_lane_keep(env, prev, current): def compute_reward(env, prev, current): - return REWARD_FUNCTIONS[env.config["reward_function"]]( - env, prev, current) + return REWARD_FUNCTIONS[env.config["reward_function"]](env, prev, current) def print_measurements(measurements): @@ -654,9 +659,8 @@ def sigmoid(x): def collided_done(py_measurements): m = py_measurements - collided = ( - m["collision_vehicles"] > 0 or m["collision_pedestrians"] > 0 or - m["collision_other"] > 0) + collided = (m["collision_vehicles"] > 0 or m["collision_pedestrians"] > 0 + or m["collision_other"] > 0) return bool(collided or m["total_reward"] < -100) diff --git a/examples/carla/models.py b/python/ray/rllib/examples/carla/models.py similarity index 81% rename from examples/carla/models.py rename to python/ray/rllib/examples/carla/models.py index 9233c9c8ed2be..3f8cc0c5ba47b 100644 --- a/examples/carla/models.py +++ b/python/ray/rllib/examples/carla/models.py @@ -20,6 +20,7 @@ class CarlaModel(Model): further fully connected layers. 
""" + # TODO(ekl): use build_layers_v2 for native dict space support def _build_layers(self, inputs, num_outputs, options): # Parse options image_shape = options["custom_options"]["image_shape"] @@ -43,8 +44,8 @@ def _build_layers(self, inputs, num_outputs, options): (inputs.shape.as_list()[1:], expected_shape) # Reshape the input vector back into its components - vision_in = tf.reshape( - inputs[:, :image_size], [tf.shape(inputs)[0]] + image_shape) + vision_in = tf.reshape(inputs[:, :image_size], + [tf.shape(inputs)[0]] + image_shape) metrics_in = inputs[:, image_size:] print("Vision in shape", vision_in) print("Metrics in shape", metrics_in) @@ -53,18 +54,26 @@ def _build_layers(self, inputs, num_outputs, options): with tf.name_scope("carla_vision"): for i, (out_size, kernel, stride) in enumerate(convs[:-1], 1): vision_in = slim.conv2d( - vision_in, out_size, kernel, stride, + vision_in, + out_size, + kernel, + stride, scope="conv{}".format(i)) out_size, kernel, stride = convs[-1] vision_in = slim.conv2d( - vision_in, out_size, kernel, stride, - padding="VALID", scope="conv_out") + vision_in, + out_size, + kernel, + stride, + padding="VALID", + scope="conv_out") vision_in = tf.squeeze(vision_in, [1, 2]) # Setup metrics layer with tf.name_scope("carla_metrics"): metrics_in = slim.fully_connected( - metrics_in, 64, + metrics_in, + 64, weights_initializer=xavier_initializer(), activation_fn=activation, scope="metrics_out") @@ -79,15 +88,18 @@ def _build_layers(self, inputs, num_outputs, options): print("Shape of concatenated out is", last_layer.shape) for size in hiddens: last_layer = slim.fully_connected( - last_layer, size, + last_layer, + size, weights_initializer=xavier_initializer(), activation_fn=activation, scope="fc{}".format(i)) i += 1 output = slim.fully_connected( - last_layer, num_outputs, + last_layer, + num_outputs, weights_initializer=normc_initializer(0.01), - activation_fn=None, scope="fc_out") + activation_fn=None, + scope="fc_out") return output, last_layer diff --git a/python/ray/rllib/examples/carla/scenarios.py b/python/ray/rllib/examples/carla/scenarios.py new file mode 100644 index 0000000000000..beedd2989d5cf --- /dev/null +++ b/python/ray/rllib/examples/carla/scenarios.py @@ -0,0 +1,131 @@ +"""Collection of Carla scenarios, including those from the CoRL 2017 paper.""" + +TEST_WEATHERS = [0, 2, 5, 7, 9, 10, 11, 12, 13] +TRAIN_WEATHERS = [1, 3, 4, 6, 8, 14] + + +def build_scenario(city, start, end, vehicles, pedestrians, max_steps, + weathers): + return { + "city": city, + "num_vehicles": vehicles, + "num_pedestrians": pedestrians, + "weather_distribution": weathers, + "start_pos_id": start, + "end_pos_id": end, + "max_steps": max_steps, + } + + +# Simple scenario for Town02 that involves driving down a road +DEFAULT_SCENARIO = build_scenario( + city="Town02", + start=36, + end=40, + vehicles=20, + pedestrians=40, + max_steps=200, + weathers=[0]) + +# Simple scenario for Town02 that involves driving down a road +LANE_KEEP = build_scenario( + city="Town02", + start=36, + end=40, + vehicles=0, + pedestrians=0, + max_steps=2000, + weathers=[0]) + +# Scenarios from the CoRL2017 paper +POSES_TOWN1_STRAIGHT = [[36, 40], [39, 35], [110, 114], [7, 3], [0, 4], [ + 68, 50 +], [61, 59], [47, 64], [147, 90], [33, 87], [26, 19], [80, 76], [45, 49], [ + 55, 44 +], [29, 107], [95, 104], [84, 34], [53, 67], [22, 17], [91, 148], [20, 107], + [78, 70], [95, 102], [68, 44], [45, 69]] + +POSES_TOWN1_ONE_CURVE = [[138, 17], [47, 16], [26, 9], [42, 49], [140, 124], [ + 85, 98 +], [65, 
133], [137, 51], [76, 66], [46, 39], [40, 60], [0, 29], [4, 129], [ + 121, 140 +], [2, 129], [78, 44], [68, 85], [41, 102], [95, 70], [68, 129], [84, 69], + [47, 79], [110, 15], [130, 17], [0, 17]] + +POSES_TOWN1_NAV = [[105, 29], [27, 130], [102, 87], [132, 27], [24, 44], [ + 96, 26 +], [34, 67], [28, 1], [140, 134], [105, 9], [148, 129], [65, 18], [21, 16], [ + 147, 97 +], [42, 51], [30, 41], [18, 107], [69, 45], [102, 95], [18, 145], [111, 64], + [79, 45], [84, 69], [73, 31], [37, 81]] + +POSES_TOWN2_STRAIGHT = [[38, 34], [4, 2], [12, 10], [62, 55], [43, 47], [ + 64, 66 +], [78, 76], [59, 57], [61, 18], [35, 39], [12, 8], [0, 18], [75, 68], [ + 54, 60 +], [45, 49], [46, 42], [53, 46], [80, 29], [65, 63], [0, 81], [54, 63], + [51, 42], [16, 19], [17, 26], [77, 68]] + +POSES_TOWN2_ONE_CURVE = [[37, 76], [8, 24], [60, 69], [38, 10], [21, 1], [ + 58, 71 +], [74, 32], [44, 0], [71, 16], [14, 24], [34, 11], [43, 14], [75, 16], [ + 80, 21 +], [3, 23], [75, 59], [50, 47], [11, 19], [77, 34], [79, 25], [40, 63], + [58, 76], [79, 55], [16, 61], [27, 11]] + +POSES_TOWN2_NAV = [[19, 66], [79, 14], [19, 57], [23, 1], [53, 76], [42, 13], [ + 31, 71 +], [33, 5], [54, 30], [10, 61], [66, 3], [27, 12], [79, 19], [2, 29], [16, 14], + [5, 57], [70, 73], [46, 67], [57, 50], [61, 49], [21, 12], + [51, 81], [77, 68], [56, 65], [43, 54]] + +TOWN1_STRAIGHT = [ + build_scenario("Town01", start, end, 0, 0, 300, TEST_WEATHERS) + for (start, end) in POSES_TOWN1_STRAIGHT +] + +TOWN1_ONE_CURVE = [ + build_scenario("Town01", start, end, 0, 0, 600, TEST_WEATHERS) + for (start, end) in POSES_TOWN1_ONE_CURVE +] + +TOWN1_NAVIGATION = [ + build_scenario("Town01", start, end, 0, 0, 900, TEST_WEATHERS) + for (start, end) in POSES_TOWN1_NAV +] + +TOWN1_NAVIGATION_DYNAMIC = [ + build_scenario("Town01", start, end, 20, 50, 900, TEST_WEATHERS) + for (start, end) in POSES_TOWN1_NAV +] + +TOWN2_STRAIGHT = [ + build_scenario("Town02", start, end, 0, 0, 300, TRAIN_WEATHERS) + for (start, end) in POSES_TOWN2_STRAIGHT +] + +TOWN2_STRAIGHT_DYNAMIC = [ + build_scenario("Town02", start, end, 20, 50, 300, TRAIN_WEATHERS) + for (start, end) in POSES_TOWN2_STRAIGHT +] + +TOWN2_ONE_CURVE = [ + build_scenario("Town02", start, end, 0, 0, 600, TRAIN_WEATHERS) + for (start, end) in POSES_TOWN2_ONE_CURVE +] + +TOWN2_NAVIGATION = [ + build_scenario("Town02", start, end, 0, 0, 900, TRAIN_WEATHERS) + for (start, end) in POSES_TOWN2_NAV +] + +TOWN2_NAVIGATION_DYNAMIC = [ + build_scenario("Town02", start, end, 20, 50, 900, TRAIN_WEATHERS) + for (start, end) in POSES_TOWN2_NAV +] + +TOWN1_ALL = (TOWN1_STRAIGHT + TOWN1_ONE_CURVE + TOWN1_NAVIGATION + + TOWN1_NAVIGATION_DYNAMIC) + +TOWN2_ALL = (TOWN2_STRAIGHT + TOWN2_ONE_CURVE + TOWN2_NAVIGATION + + TOWN2_NAVIGATION_DYNAMIC) diff --git a/examples/carla/train_a3c.py b/python/ray/rllib/examples/carla/train_a3c.py similarity index 84% rename from examples/carla/train_a3c.py rename to python/ray/rllib/examples/carla/train_a3c.py index 75856aef266e0..8fbcfbc576d1e 100644 --- a/examples/carla/train_a3c.py +++ b/python/ray/rllib/examples/carla/train_a3c.py @@ -3,13 +3,12 @@ from __future__ import print_function import ray -from ray.tune import grid_search, register_env, run_experiments +from ray.tune import grid_search, run_experiments from env import CarlaEnv, ENV_CONFIG from models import register_carla_model from scenarios import TOWN2_STRAIGHT -env_name = "carla_env" env_config = ENV_CONFIG.copy() env_config.update({ "verbose": False, @@ -23,7 +22,6 @@ "scenarios": TOWN2_STRAIGHT, }) -register_env(env_name, 
lambda env_config: CarlaEnv(env_config)) register_carla_model() redis_address = ray.services.get_node_ip_address() + ":6379" @@ -31,8 +29,7 @@ run_experiments({ "carla-a3c": { "run": "A3C", - "env": "carla_env", - "trial_resources": {"cpu": 5, "extra_gpu": 2}, + "env": CarlaEnv, "config": { "env_config": env_config, "use_gpu_for_workers": True, diff --git a/examples/carla/train_dqn.py b/python/ray/rllib/examples/carla/train_dqn.py similarity index 71% rename from examples/carla/train_dqn.py rename to python/ray/rllib/examples/carla/train_dqn.py index 6180ca48f0dda..27aa65444d386 100644 --- a/examples/carla/train_dqn.py +++ b/python/ray/rllib/examples/carla/train_dqn.py @@ -3,13 +3,12 @@ from __future__ import print_function import ray -from ray.tune import register_env, run_experiments +from ray.tune import run_experiments from env import CarlaEnv, ENV_CONFIG from models import register_carla_model from scenarios import TOWN2_ONE_CURVE -env_name = "carla_env" env_config = ENV_CONFIG.copy() env_config.update({ "verbose": False, @@ -21,25 +20,29 @@ "scenarios": TOWN2_ONE_CURVE, }) -register_env(env_name, lambda env_config: CarlaEnv(env_config)) register_carla_model() ray.init() + + +def shape_out(spec): + return (spec.config.env_config.framestack * + (spec.config.env_config.use_depth_camera and 1 or 3)) + + run_experiments({ "carla-dqn": { "run": "DQN", - "env": "carla_env", - "trial_resources": {"cpu": 4, "gpu": 1}, + "env": CarlaEnv, "config": { "env_config": env_config, "model": { "custom_model": "carla", "custom_options": { "image_shape": [ - 80, 80, - lambda spec: spec.config.env_config.framestack * ( - spec.config.env_config.use_depth_camera and 1 or 3 - ), + 80, + 80, + shape_out, ], }, "conv_filters": [ @@ -53,7 +56,9 @@ "schedule_max_timesteps": 100000, "gamma": 0.8, "tf_session_args": { - "gpu_options": {"allow_growth": True}, + "gpu_options": { + "allow_growth": True + }, }, }, }, diff --git a/examples/carla/train_ppo.py b/python/ray/rllib/examples/carla/train_ppo.py similarity index 70% rename from examples/carla/train_ppo.py rename to python/ray/rllib/examples/carla/train_ppo.py index 4f3ebf5eab830..6c49240142c26 100644 --- a/examples/carla/train_ppo.py +++ b/python/ray/rllib/examples/carla/train_ppo.py @@ -3,13 +3,12 @@ from __future__ import print_function import ray -from ray.tune import register_env, run_experiments +from ray.tune import run_experiments from env import CarlaEnv, ENV_CONFIG from models import register_carla_model from scenarios import TOWN2_STRAIGHT -env_name = "carla_env" env_config = ENV_CONFIG.copy() env_config.update({ "verbose": False, @@ -20,22 +19,21 @@ "server_map": "/Game/Maps/Town02", "scenarios": TOWN2_STRAIGHT, }) -register_env(env_name, lambda env_config: CarlaEnv(env_config)) register_carla_model() ray.init(redirect_output=True) run_experiments({ "carla": { "run": "PPO", - "env": "carla_env", - "trial_resources": {"cpu": 4, "gpu": 1}, + "env": CarlaEnv, "config": { "env_config": env_config, "model": { "custom_model": "carla", "custom_options": { "image_shape": [ - env_config["x_res"], env_config["y_res"], 6], + env_config["x_res"], env_config["y_res"], 6 + ], }, "conv_filters": [ [16, [8, 8], 4], @@ -44,17 +42,14 @@ ], }, "num_workers": 1, - "timesteps_per_batch": 2000, - "min_steps_per_task": 100, + "train_batch_size": 2000, + "sample_batch_size": 100, "lambda": 0.95, "clip_param": 0.2, "num_sgd_iter": 20, - "sgd_stepsize": 0.0001, - "sgd_batchsize": 32, - "devices": ["/gpu:0"], - "tf_session_args": { - "gpu_options": {"allow_growth": True} - } 
+ "lr": 0.0001, + "sgd_minibatch_size": 32, + "num_gpus": 1, }, }, }) diff --git a/python/ray/rllib/examples/cartpole_lstm.py b/python/ray/rllib/examples/cartpole_lstm.py index e3d0ddc4c5701..ddc89c47e3b34 100644 --- a/python/ray/rllib/examples/cartpole_lstm.py +++ b/python/ray/rllib/examples/cartpole_lstm.py @@ -14,6 +14,8 @@ parser = argparse.ArgumentParser() parser.add_argument("--stop", type=int, default=200) +parser.add_argument("--use-prev-action-reward", action="store_true") +parser.add_argument("--run", type=str, default="PPO") class CartPoleStatelessEnv(gym.Env): @@ -163,18 +165,32 @@ def close(self): tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv()) ray.init() + + configs = { + "PPO": { + "num_sgd_iter": 5, + }, + "IMPALA": { + "num_workers": 2, + "num_gpus": 0, + "vf_loss_coeff": 0.01, + }, + } + tune.run_experiments({ "test": { "env": "cartpole_stateless", - "run": "PPO", + "run": args.run, "stop": { "episode_reward_mean": args.stop }, - "config": { - "num_sgd_iter": 5, - "model": { - "use_lstm": True, - }, - }, + "config": dict( + configs[args.run], **{ + "model": { + "use_lstm": True, + "lstm_use_prev_action_reward": args. + use_prev_action_reward, + }, + }), } }) diff --git a/examples/custom_env/custom_env.py b/python/ray/rllib/examples/custom_env.py similarity index 81% rename from examples/custom_env/custom_env.py rename to python/ray/rllib/examples/custom_env.py index b5a3240eaad0a..0d96eef6acb64 100644 --- a/examples/custom_env/custom_env.py +++ b/python/ray/rllib/examples/custom_env.py @@ -11,7 +11,6 @@ import ray from ray.tune import run_experiments -from ray.tune.registry import register_env class SimpleCorridor(gym.Env): @@ -24,7 +23,7 @@ def __init__(self, config): self.cur_pos = 0 self.action_space = Discrete(2) self.observation_space = Box( - 0.0, self.end_pos, shape=(1,), dtype=np.float32) + 0.0, self.end_pos, shape=(1, ), dtype=np.float32) self._spec = EnvSpec("SimpleCorridor-{}-v0".format(self.end_pos)) def reset(self): @@ -32,7 +31,7 @@ def reset(self): return [self.cur_pos] def step(self, action): - assert action in [0, 1] + assert action in [0, 1], action if action == 0 and self.cur_pos > 0: self.cur_pos -= 1 elif action == 1: @@ -42,13 +41,13 @@ def step(self, action): if __name__ == "__main__": - env_creator_name = "corridor" - register_env(env_creator_name, lambda config: SimpleCorridor(config)) + # Can also register the env creator function explicitly with: + # register_env("corridor", lambda config: SimpleCorridor(config)) ray.init() run_experiments({ "demo": { "run": "PPO", - "env": "corridor", + "env": SimpleCorridor, # or "corridor" if registered above "config": { "env_config": { "corridor_length": 5, diff --git a/python/ray/rllib/examples/custom_metrics_and_callbacks.py b/python/ray/rllib/examples/custom_metrics_and_callbacks.py new file mode 100644 index 0000000000000..af1d25f16cadf --- /dev/null +++ b/python/ray/rllib/examples/custom_metrics_and_callbacks.py @@ -0,0 +1,77 @@ +"""Example of using RLlib's debug callbacks. + +Here we use callbacks to track the average CartPole pole angle magnitude as a +custom metric. 
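+
+It can be run directly, e.g.:
+
+    python custom_metrics_and_callbacks.py --num-iters=2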
+""" + +import argparse +import numpy as np + +import ray +from ray import tune + + +def on_episode_start(info): + episode = info["episode"] + print("episode {} started".format(episode.episode_id)) + episode.user_data["pole_angles"] = [] + + +def on_episode_step(info): + episode = info["episode"] + pole_angle = abs(episode.last_observation_for()[2]) + episode.user_data["pole_angles"].append(pole_angle) + + +def on_episode_end(info): + episode = info["episode"] + pole_angle = np.mean(episode.user_data["pole_angles"]) + print("episode {} ended with length {} and pole angles {}".format( + episode.episode_id, episode.length, pole_angle)) + episode.custom_metrics["pole_angle"] = pole_angle + + +def on_sample_end(info): + print("returned sample batch of size {}".format(info["samples"].count)) + + +def on_train_result(info): + print("agent.train() result: {} -> {} episodes".format( + info["agent"], info["result"]["episodes_this_iter"])) + # you can mutate the result dict to add new fields to return + info["result"]["callback_ok"] = True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-iters", type=int, default=2000) + args = parser.parse_args() + + ray.init() + trials = tune.run_experiments({ + "test": { + "env": "CartPole-v0", + "run": "PG", + "stop": { + "training_iteration": args.num_iters, + }, + "config": { + "callbacks": { + "on_episode_start": tune.function(on_episode_start), + "on_episode_step": tune.function(on_episode_step), + "on_episode_end": tune.function(on_episode_end), + "on_sample_end": tune.function(on_sample_end), + "on_train_result": tune.function(on_train_result), + }, + }, + } + }) + + # verify custom metrics for integration tests + custom_metrics = trials[0].last_result["custom_metrics"] + print(custom_metrics) + assert "pole_angle_mean" in custom_metrics + assert "pole_angle_min" in custom_metrics + assert "pole_angle_max" in custom_metrics + assert type(custom_metrics["pole_angle_mean"]) is float + assert "callback_ok" in trials[0].last_result diff --git a/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py deleted file mode 100644 index 9559648290dae..0000000000000 --- a/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py +++ /dev/null @@ -1,59 +0,0 @@ -""" Multiagent mountain car. Each agent outputs an action which -is summed to form the total action. 
This is a discrete -multiagent example -""" - -import gym -from gym.envs.registration import register - -import ray -import ray.rllib.agents.ppo as ppo -from ray.tune.registry import register_env - -env_name = "MultiAgentMountainCarEnv" - -env_version_num = 0 -env_name = env_name + '-v' + str(env_version_num) - - -def pass_params_to_gym(env_name): - global env_version_num - - register( - id=env_name, - entry_point=( - "ray.rllib.examples.legacy_multiagent.multiagent_mountaincar_env:" - "MultiAgentMountainCarEnv"), - max_episode_steps=200, - kwargs={}) - - -def create_env(env_config): - pass_params_to_gym(env_name) - env = gym.envs.make(env_name) - return env - - -if __name__ == '__main__': - register_env(env_name, lambda env_config: create_env(env_config)) - config = ppo.DEFAULT_CONFIG.copy() - horizon = 10 - num_cpus = 4 - ray.init(num_cpus=num_cpus, redirect_output=True) - config["num_workers"] = num_cpus - config["train_batch_size"] = 1000 - config["num_sgd_iter"] = 10 - config["gamma"] = 0.999 - config["horizon"] = horizon - config["use_gae"] = False - config["model"].update({"fcnet_hiddens": [256, 256]}) - options = { - "multiagent_obs_shapes": [2, 2], - "multiagent_act_shapes": [1, 1], - "multiagent_shared_model": False, - "multiagent_fcnet_hiddens": [[32, 32]] * 2 - } - config["model"].update({"custom_options": options}) - alg = ppo.PPOAgent(env=env_name, config=config) - for i in range(1): - alg.train() diff --git a/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar_env.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar_env.py deleted file mode 100644 index c120f00c99ec7..0000000000000 --- a/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar_env.py +++ /dev/null @@ -1,51 +0,0 @@ -from math import cos -from gym.spaces import Box, Tuple, Discrete -import numpy as np -from gym.envs.classic_control.mountain_car import MountainCarEnv -""" -Multiagent mountain car that sums and then -averages its actions to produce the velocity -""" - - -class MultiAgentMountainCarEnv(MountainCarEnv): - def __init__(self): - self.min_position = -1.2 - self.max_position = 0.6 - self.max_speed = 0.07 - self.goal_position = 0.5 - - self.low = np.array([self.min_position, -self.max_speed]) - self.high = np.array([self.max_position, self.max_speed]) - - self.viewer = None - - self.action_space = [Discrete(3) for _ in range(2)] - self.observation_space = Tuple( - [Box(self.low, self.high, dtype=np.float32) for _ in range(2)]) - - self.seed() - self.reset() - - def step(self, action): - summed_act = 0.5 * np.sum(action) - - position, velocity = self.state - velocity += (summed_act - 1) * 0.001 - velocity += cos(3 * position) * (-0.0025) - velocity = np.clip(velocity, -self.max_speed, self.max_speed) - position += velocity - position = np.clip(position, self.min_position, self.max_position) - if (position == self.min_position and velocity < 0): - velocity = 0 - - done = bool(position >= self.goal_position) - - reward = position - - self.state = (position, velocity) - return [np.array(self.state) for _ in range(2)], reward, done, {} - - def reset(self): - self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0]) - return [np.array(self.state) for _ in range(2)] diff --git a/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py deleted file mode 100644 index b183ff2c0b157..0000000000000 --- a/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py +++ 
/dev/null @@ -1,60 +0,0 @@ -""" Run script for multiagent pendulum env. Each agent outputs a -torque which is summed to form the total torque. This is a -continuous multiagent example -""" - -import gym -from gym.envs.registration import register - -import ray -import ray.rllib.agents.ppo as ppo -from ray.tune.registry import register_env - -env_name = "MultiAgentPendulumEnv" - -env_version_num = 0 -env_name = env_name + '-v' + str(env_version_num) - - -def pass_params_to_gym(env_name): - global env_version_num - - register( - id=env_name, - entry_point=( - "ray.rllib.examples.legacy_multiagent.multiagent_pendulum_env:" - "MultiAgentPendulumEnv"), - max_episode_steps=100, - kwargs={}) - - -def create_env(env_config): - pass_params_to_gym(env_name) - env = gym.envs.make(env_name) - return env - - -if __name__ == '__main__': - register_env(env_name, lambda env_config: create_env(env_config)) - config = ppo.DEFAULT_CONFIG.copy() - horizon = 10 - num_cpus = 4 - ray.init(num_cpus=num_cpus, redirect_output=True) - config["num_workers"] = num_cpus - config["train_batch_size"] = 1000 - config["sgd_minibatch_size"] = 10 - config["num_sgd_iter"] = 10 - config["gamma"] = 0.999 - config["horizon"] = horizon - config["use_gae"] = True - config["model"].update({"fcnet_hiddens": [256, 256]}) - options = { - "multiagent_obs_shapes": [3, 3], - "multiagent_act_shapes": [1, 1], - "multiagent_shared_model": True, - "multiagent_fcnet_hiddens": [[32, 32]] * 2 - } - config["model"].update({"custom_options": options}) - alg = ppo.PPOAgent(env=env_name, config=config) - for i in range(1): - alg.train() diff --git a/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum_env.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum_env.py deleted file mode 100644 index 02645832729f7..0000000000000 --- a/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum_env.py +++ /dev/null @@ -1,74 +0,0 @@ -from gym.spaces import Box, Tuple -from gym.utils import seeding -from gym.envs.classic_control.pendulum import PendulumEnv -import numpy as np -""" - Multiagent pendulum that sums its torques to generate an action -""" - - -class MultiAgentPendulumEnv(PendulumEnv): - metadata = { - 'render.modes': ['human', 'rgb_array'], - 'video.frames_per_second': 30 - } - - def __init__(self): - self.max_speed = 8 - self.max_torque = 2. - self.dt = .05 - self.viewer = None - - high = np.array([1., 1., self.max_speed]) - self.action_space = [ - Box(low=-self.max_torque / 2, - high=self.max_torque / 2, - shape=(1, ), - dtype=np.float32) for _ in range(2) - ] - self.observation_space = Tuple( - [Box(low=-high, high=high, dtype=np.float32) for _ in range(2)]) - - self.seed() - - def seed(self, seed=None): - self.np_random, seed = seeding.np_random(seed) - return [seed] - - def step(self, u): - th, thdot = self.state # th := theta - - summed_u = np.sum(u) - g = 10. - m = 1. - length = 1. - dt = self.dt - - summed_u = np.clip(summed_u, -self.max_torque, self.max_torque) - self.last_u = summed_u # for rendering - costs = self.angle_normalize(th) ** 2 + .1 * thdot ** 2 + \ - .001 * (summed_u ** 2) - - newthdot = thdot + (-3 * g / (2 * length) * np.sin(th + np.pi) + 3. 
/ - (m * length**2) * summed_u) * dt - newth = th + newthdot * dt - newthdot = np.clip(newthdot, -self.max_speed, self.max_speed) - - self.state = np.array([newth, newthdot]) - return self._get_obs(), -costs, False, {} - - def reset(self): - high = np.array([np.pi, 1]) - self.state = self.np_random.uniform(low=-high, high=high) - self.last_u = None - return self._get_obs() - - def _get_obs(self): - theta, thetadot = self.state - return [ - np.array([np.cos(theta), np.sin(theta), thetadot]) - for _ in range(2) - ] - - def angle_normalize(self, x): - return (((x + np.pi) % (2 * np.pi)) - np.pi) diff --git a/python/ray/rllib/examples/multiagent_cartpole.py b/python/ray/rllib/examples/multiagent_cartpole.py index 8faeb184bf6d7..e2ab5270f9d87 100644 --- a/python/ray/rllib/examples/multiagent_cartpole.py +++ b/python/ray/rllib/examples/multiagent_cartpole.py @@ -16,9 +16,13 @@ import gym import random +import tensorflow as tf +import tensorflow.contrib.slim as slim + import ray from ray import tune -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph +from ray.rllib.models import Model, ModelCatalog from ray.rllib.test.test_multi_agent_env import MultiCartpole from ray.tune import run_experiments from ray.tune.registry import register_env @@ -29,38 +33,82 @@ parser.add_argument("--num-policies", type=int, default=2) parser.add_argument("--num-iters", type=int, default=20) + +class CustomModel1(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + # Example of (optional) weight sharing between two different policies. + # Here, we share the variables defined in the 'shared' variable scope + # by entering it explicitly with tf.AUTO_REUSE. This creates the + # variables for the 'fc1' layer in a global scope called 'shared' + # outside of the policy's normal variable scope. 
+ with tf.variable_scope( + tf.VariableScope(tf.AUTO_REUSE, "shared"), + reuse=tf.AUTO_REUSE, + auxiliary_name_scope=False): + last_layer = slim.fully_connected( + input_dict["obs"], 64, activation_fn=tf.nn.relu, scope="fc1") + last_layer = slim.fully_connected( + last_layer, 64, activation_fn=tf.nn.relu, scope="fc2") + output = slim.fully_connected( + last_layer, num_outputs, activation_fn=None, scope="fc_out") + return output, last_layer + + +class CustomModel2(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + # Weights shared with CustomModel1 + with tf.variable_scope( + tf.VariableScope(tf.AUTO_REUSE, "shared"), + reuse=tf.AUTO_REUSE, + auxiliary_name_scope=False): + last_layer = slim.fully_connected( + input_dict["obs"], 64, activation_fn=tf.nn.relu, scope="fc1") + last_layer = slim.fully_connected( + last_layer, 64, activation_fn=tf.nn.relu, scope="fc2") + output = slim.fully_connected( + last_layer, num_outputs, activation_fn=None, scope="fc_out") + return output, last_layer + + if __name__ == "__main__": args = parser.parse_args() ray.init() # Simple environment with `num_agents` independent cartpole entities register_env("multi_cartpole", lambda _: MultiCartpole(args.num_agents)) + ModelCatalog.register_custom_model("model1", CustomModel1) + ModelCatalog.register_custom_model("model2", CustomModel2) single_env = gym.make("CartPole-v0") obs_space = single_env.observation_space act_space = single_env.action_space - def gen_policy(): + # Each policy can have a different configuration (including custom model) + def gen_policy(i): config = { - "gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]), - "n_step": random.choice([1, 2, 3, 4, 5]), + "model": { + "custom_model": ["model1", "model2"][i % 2], + }, + "gamma": random.choice([0.95, 0.99]), } - return (PGPolicyGraph, obs_space, act_space, config) + return (PPOPolicyGraph, obs_space, act_space, config) - # Setup PG with an ensemble of `num_policies` different policy graphs + # Setup PPO with an ensemble of `num_policies` different policy graphs policy_graphs = { - "policy_{}".format(i): gen_policy() + "policy_{}".format(i): gen_policy(i) for i in range(args.num_policies) } policy_ids = list(policy_graphs.keys()) run_experiments({ "test": { - "run": "PG", + "run": "PPO", "env": "multi_cartpole", "stop": { "training_iteration": args.num_iters }, "config": { + "log_level": "DEBUG", + "num_sgd_iter": 10, "multiagent": { "policy_graphs": policy_graphs, "policy_mapping_fn": tune.function( diff --git a/python/ray/rllib/examples/multiagent_two_trainers.py b/python/ray/rllib/examples/multiagent_two_trainers.py index e2c8bc97a8c23..46831db452b6d 100644 --- a/python/ray/rllib/examples/multiagent_two_trainers.py +++ b/python/ray/rllib/examples/multiagent_two_trainers.py @@ -57,7 +57,6 @@ def policy_mapping_fn(agent_id): "policy_mapping_fn": policy_mapping_fn, "policies_to_train": ["ppo_policy"], }, - "simple_optimizer": True, # disable filters, otherwise we would need to synchronize those # as well to the DQN agent "observation_filter": "NoFilter", diff --git a/python/ray/rllib/examples/parametric_action_cartpole.py b/python/ray/rllib/examples/parametric_action_cartpole.py new file mode 100644 index 0000000000000..a1438f0a24123 --- /dev/null +++ b/python/ray/rllib/examples/parametric_action_cartpole.py @@ -0,0 +1,196 @@ +"""Example of handling variable length and/or parametric action spaces. 
+ +This is a toy example of the action-embedding based approach for handling large +discrete action spaces (potentially infinite in size), similar to how +OpenAI Five works: + + https://neuro.cs.ut.ee/the-use-of-embeddings-in-openai-five/ + +This currently works with RLlib's policy gradient style algorithms +(e.g., PG, PPO, IMPALA, A2C) and also DQN. + +Note that since the model outputs now include "-inf" tf.float32.min +values, not all algorithm options are supported at the moment. For example, +algorithms might crash if they don't properly ignore the -inf action scores. +Working configurations are given below. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import random +import numpy as np +import gym +from gym.spaces import Box, Discrete, Dict +import tensorflow as tf +import tensorflow.contrib.slim as slim + +import ray +from ray.rllib.models import Model, ModelCatalog +from ray.rllib.models.misc import normc_initializer +from ray.tune import run_experiments +from ray.tune.registry import register_env + +parser = argparse.ArgumentParser() +parser.add_argument("--stop", type=int, default=200) +parser.add_argument("--run", type=str, default="PPO") + + +class ParametricActionCartpole(gym.Env): + """Parametric action version of CartPole. + + In this env there are only ever two valid actions, but we pretend there are + actually up to `max_avail_actions` actions that can be taken, and the two + valid actions are randomly hidden among this set. + + At each step, we emit a dict of: + - the actual cart observation + - a mask of valid actions (e.g., [0, 0, 1, 0, 0, 1] for 6 max avail) + - the list of action embeddings (w/ zeroes for invalid actions) (e.g., + [[0, 0], + [0, 0], + [-0.2322, -0.2569], + [0, 0], + [0, 0], + [0.7878, 1.2297]] for max_avail_actions=6) + + In a real environment, the actions embeddings would be larger than two + units of course, and also there would be a variable number of valid actions + per step instead of always [LEFT, RIGHT]. 
+ """ + + def __init__(self, max_avail_actions): + # Use simple random 2-unit action embeddings for [LEFT, RIGHT] + self.left_action_embed = np.random.randn(2) + self.right_action_embed = np.random.randn(2) + self.action_space = Discrete(max_avail_actions) + self.wrapped = gym.make("CartPole-v0") + self.observation_space = Dict({ + "action_mask": Box(0, 1, shape=(max_avail_actions, )), + "avail_actions": Box(-1, 1, shape=(max_avail_actions, 2)), + "cart": self.wrapped.observation_space, + }) + + def update_avail_actions(self): + self.action_assignments = [[0, 0]] * self.action_space.n + self.action_mask = [0] * self.action_space.n + self.left_idx, self.right_idx = random.sample( + range(self.action_space.n), 2) + self.action_assignments[self.left_idx] = self.left_action_embed + self.action_assignments[self.right_idx] = self.right_action_embed + self.action_mask[self.left_idx] = 1 + self.action_mask[self.right_idx] = 1 + + def reset(self): + self.update_avail_actions() + return { + "action_mask": self.action_mask, + "avail_actions": self.action_assignments, + "cart": self.wrapped.reset(), + } + + def step(self, action): + if action == self.left_idx: + actual_action = 0 + elif action == self.right_idx: + actual_action = 1 + else: + raise ValueError( + "Chosen action was not one of the non-zero action embeddings", + action, self.action_assignments, self.action_mask, + self.left_idx, self.right_idx) + orig_obs, rew, done, info = self.wrapped.step(actual_action) + self.update_avail_actions() + obs = { + "action_mask": self.action_mask, + "avail_actions": self.action_assignments, + "cart": orig_obs, + } + return obs, rew, done, info + + +class ParametricActionsModel(Model): + """Parametric action model that handles the dot product and masking. + + This assumes the outputs are logits for a single Categorical action dist. + Getting this to work with a more complex output (e.g., if the action space + is a tuple of several distributions) is also possible but left as an + exercise to the reader. + """ + + def _build_layers_v2(self, input_dict, num_outputs, options): + # Extract the available actions tensor from the observation. + avail_actions = input_dict["obs"]["avail_actions"] + action_mask = input_dict["obs"]["action_mask"] + action_embed_size = avail_actions.shape[2].value + if num_outputs != avail_actions.shape[1].value: + raise ValueError( + "This model assumes num outputs is equal to max avail actions", + num_outputs, avail_actions) + + # Standard FC net component. + last_layer = input_dict["obs"]["cart"] + hiddens = [256, 256] + for i, size in enumerate(hiddens): + label = "fc{}".format(i) + last_layer = slim.fully_connected( + last_layer, + size, + weights_initializer=normc_initializer(1.0), + activation_fn=tf.nn.tanh, + scope=label) + output = slim.fully_connected( + last_layer, + action_embed_size, + weights_initializer=normc_initializer(0.01), + activation_fn=None, + scope="fc_out") + + # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the + # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE]. + intent_vector = tf.expand_dims(output, 1) + + # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS]. 
+ action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2) + + # Mask out invalid actions (use tf.float32.min for stability) + inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min) + masked_logits = inf_mask + action_logits + + return masked_logits, last_layer + + +if __name__ == "__main__": + args = parser.parse_args() + ray.init() + + ModelCatalog.register_custom_model("pa_model", ParametricActionsModel) + register_env("pa_cartpole", lambda _: ParametricActionCartpole(10)) + if args.run == "PPO": + cfg = { + "observation_filter": "NoFilter", # don't filter the action list + "vf_share_layers": True, # don't create duplicate value model + } + elif args.run == "DQN": + cfg = { + "hiddens": [], # don't postprocess the action scores + } + else: + cfg = {} + run_experiments({ + "parametric_cartpole": { + "run": args.run, + "env": "pa_cartpole", + "stop": { + "episode_reward_mean": args.stop, + }, + "config": dict({ + "model": { + "custom_model": "pa_model", + }, + "num_workers": 0, + }, **cfg), + }, + }) diff --git a/python/ray/rllib/examples/serving/cartpole_client.py b/python/ray/rllib/examples/serving/cartpole_client.py index 6f6a2e189c69a..b116eb9aa356b 100755 --- a/python/ray/rllib/examples/serving/cartpole_client.py +++ b/python/ray/rllib/examples/serving/cartpole_client.py @@ -29,7 +29,7 @@ if __name__ == "__main__": args = parser.parse_args() env = gym.make("CartPole-v0") - client = PolicyClient("http://localhost:8900") + client = PolicyClient("http://localhost:9900") eid = client.start_episode(training_enabled=not args.no_train) obs = env.reset() diff --git a/python/ray/rllib/examples/serving/cartpole_server.py b/python/ray/rllib/examples/serving/cartpole_server.py index dbbdf85809ff8..40260350ca3dd 100755 --- a/python/ray/rllib/examples/serving/cartpole_server.py +++ b/python/ray/rllib/examples/serving/cartpole_server.py @@ -14,19 +14,19 @@ import ray from ray.rllib.agents.dqn import DQNAgent -from ray.rllib.env.serving_env import ServingEnv +from ray.rllib.env.external_env import ExternalEnv from ray.rllib.utils.policy_server import PolicyServer from ray.tune.logger import pretty_print from ray.tune.registry import register_env SERVER_ADDRESS = "localhost" -SERVER_PORT = 8900 +SERVER_PORT = 9900 CHECKPOINT_FILE = "last_checkpoint.out" -class CartpoleServing(ServingEnv): +class CartpoleServing(ExternalEnv): def __init__(self): - ServingEnv.__init__( + ExternalEnv.__init__( self, spaces.Discrete(2), spaces.Box(low=-10, high=10, shape=(4, ), dtype=np.float32)) diff --git a/python/ray/rllib/examples/serving/test.sh b/python/ray/rllib/examples/serving/test.sh index d443a44a43223..d1dfa1e899c57 100755 --- a/python/ray/rllib/examples/serving/test.sh +++ b/python/ray/rllib/examples/serving/test.sh @@ -4,7 +4,7 @@ pkill -f cartpole_server.py (python cartpole_server.py 2>&1 | grep -v 200) & pid=$! -while ! curl localhost:8900; do +while ! 
curl localhost:9900; do sleep 1 done diff --git a/python/ray/rllib/models/__init__.py b/python/ray/rllib/models/__init__.py index ddfdd16b8ba18..52e47e807b3ff 100644 --- a/python/ray/rllib/models/__init__.py +++ b/python/ray/rllib/models/__init__.py @@ -1,4 +1,4 @@ -from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS from ray.rllib.models.action_dist import (ActionDistribution, Categorical, DiagGaussian, Deterministic) from ray.rllib.models.model import Model @@ -7,6 +7,14 @@ from ray.rllib.models.lstm import LSTM __all__ = [ - "ActionDistribution", "Categorical", "DiagGaussian", "Deterministic", - "ModelCatalog", "Model", "Preprocessor", "FullyConnectedNetwork", "LSTM" + "ActionDistribution", + "Categorical", + "DiagGaussian", + "Deterministic", + "ModelCatalog", + "Model", + "Preprocessor", + "FullyConnectedNetwork", + "LSTM", + "MODEL_DEFAULTS", ] diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py index b0cfe4141af16..f2a69efaf9b03 100644 --- a/python/ray/rllib/models/action_dist.py +++ b/python/ray/rllib/models/action_dist.py @@ -2,9 +2,12 @@ from __future__ import division from __future__ import print_function +from collections import namedtuple +import distutils.version import tensorflow as tf import numpy as np -import distutils.version + +from ray.rllib.utils.annotations import override use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion("1.5.0")) @@ -40,10 +43,12 @@ def sample(self): class Categorical(ActionDistribution): """Categorical distribution for discrete action spaces.""" + @override(ActionDistribution) def logp(self, x): return -tf.nn.sparse_softmax_cross_entropy_with_logits( logits=self.inputs, labels=x) + @override(ActionDistribution) def entropy(self): if use_tf150_api: a0 = self.inputs - tf.reduce_max( @@ -59,6 +64,7 @@ def entropy(self): p0 = ea0 / z0 return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1]) + @override(ActionDistribution) def kl(self, other): if use_tf150_api: a0 = self.inputs - tf.reduce_max( @@ -82,6 +88,7 @@ def kl(self, other): return tf.reduce_sum( p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), reduction_indices=[1]) + @override(ActionDistribution) def sample(self): return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1) @@ -93,28 +100,21 @@ class DiagGaussian(ActionDistribution): second half the gaussian standard deviations. """ - def __init__(self, inputs, low=None, high=None): + def __init__(self, inputs): ActionDistribution.__init__(self, inputs) mean, log_std = tf.split(inputs, 2, axis=1) self.mean = mean - self.low = low - self.high = high - - # Squash to range if specified. 
- # TODO(ekl) might make sense to use a beta distribution instead: - # http://proceedings.mlr.press/v70/chou17a/chou17a.pdf - if low is not None: - self.mean = low + tf.sigmoid(self.mean) * (high - low) - self.log_std = log_std self.std = tf.exp(log_std) + @override(ActionDistribution) def logp(self, x): return (-0.5 * tf.reduce_sum( tf.square((x - self.mean) / self.std), reduction_indices=[1]) - 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) - tf.reduce_sum(self.log_std, reduction_indices=[1])) + @override(ActionDistribution) def kl(self, other): assert isinstance(other, DiagGaussian) return tf.reduce_sum( @@ -123,16 +123,15 @@ def kl(self, other): (2.0 * tf.square(other.std)) - 0.5, reduction_indices=[1]) + @override(ActionDistribution) def entropy(self): return tf.reduce_sum( .5 * self.log_std + .5 * np.log(2.0 * np.pi * np.e), reduction_indices=[1]) + @override(ActionDistribution) def sample(self): - out = self.mean + self.std * tf.random_normal(tf.shape(self.mean)) - if self.low is not None: - out = tf.clip_by_value(out, self.low, self.high) - return out + return self.mean + self.std * tf.random_normal(tf.shape(self.mean)) class Deterministic(ActionDistribution): @@ -141,38 +140,11 @@ class Deterministic(ActionDistribution): This is similar to DiagGaussian with standard deviation zero. """ + @override(ActionDistribution) def sample(self): return self.inputs -def squash_to_range(dist_cls, low, high): - """Squashes an action distribution to a range in (low, high). - - Arguments: - dist_cls (class): ActionDistribution class to wrap. - low (float|array): Scalar value or array of values. - high (float|array): Scalar value or array of values. - """ - - class SquashToRangeWrapper(dist_cls): - def __init__(self, inputs): - dist_cls.__init__(self, inputs, low=low, high=high) - - def logp(self, x): - return dist_cls.logp(self, x) - - def kl(self, other): - return dist_cls.kl(self, other) - - def entropy(self): - return dist_cls.entropy(self) - - def sample(self): - return dist_cls.sample(self) - - return SquashToRangeWrapper - - class MultiActionDistribution(ActionDistribution): """Action distribution that operates for list of actions. 
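To make the mean/log-std split of DiagGaussian concrete, a small usage sketch against ModelCatalog.get_action_dist() (updated in catalog.py further below); the Box space here is hypothetical:

import gym
import numpy as np

from ray.rllib.models import MODEL_DEFAULTS, ModelCatalog
from ray.rllib.models.action_dist import DiagGaussian

space = gym.spaces.Box(-1.0, 1.0, shape=(2, ), dtype=np.float32)
dist_cls, dist_dim = ModelCatalog.get_action_dist(space, MODEL_DEFAULTS)
# The model must emit dist_dim outputs, interpreted as
# [mean_0, mean_1, log_std_0, log_std_1].
assert dist_cls is DiagGaussian and dist_dim == 4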
@@ -188,8 +160,8 @@ def __init__(self, inputs, action_space, child_distributions, input_lens): child_list.append(distribution(split_inputs[i])) self.child_distributions = child_list + @override(ActionDistribution) def logp(self, x): - """The log-likelihood of the action distribution.""" split_indices = [] for dist in self.child_distributions: if isinstance(dist, Categorical): @@ -208,8 +180,8 @@ def logp(self, x): ]) return np.sum(log_list) + @override(ActionDistribution) def kl(self, other): - """The KL-divergence between two action distributions.""" kl_list = np.asarray([ distribution.kl(other_distribution) for distribution, other_distribution in zip( @@ -217,12 +189,15 @@ def kl(self, other): ]) return np.sum(kl_list) + @override(ActionDistribution) def entropy(self): - """The entropy of the action distribution.""" entropy_list = np.array( [s.entropy() for s in self.child_distributions]) return np.sum(entropy_list) + @override(ActionDistribution) def sample(self): - """Draw a sample from the action distribution.""" - return [[s.sample() for s in self.child_distributions]] + return TupleActions([s.sample() for s in self.child_distributions]) + + +TupleActions = namedtuple("TupleActions", ["batches"]) diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index b98061fdd02a4..822af4a37e66f 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -3,6 +3,7 @@ from __future__ import print_function import gym +import logging import numpy as np import tensorflow as tf from functools import partial @@ -10,38 +11,67 @@ from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \ _global_registry +from ray.rllib.env.async_vector_env import _ExternalEnvToAsync +from ray.rllib.env.external_env import ExternalEnv +from ray.rllib.env.vector_env import VectorEnv from ray.rllib.models.action_dist import ( - Categorical, Deterministic, DiagGaussian, MultiActionDistribution, - squash_to_range) + Categorical, Deterministic, DiagGaussian, MultiActionDistribution) from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.fcnet import FullyConnectedNetwork from ray.rllib.models.visionnet import VisionNetwork from ray.rllib.models.lstm import LSTM -from ray.rllib.models.multiagentfcnet import MultiAgentFullyConnectedNetwork -MODEL_CONFIGS = [ +logger = logging.getLogger(__name__) + +# yapf: disable +# __sphinx_doc_begin__ +MODEL_DEFAULTS = { # === Built-in options === # Filter config. 
List of [out_channels, kernel, stride] for each filter - "conv_filters", - "conv_activation", # Nonlinearity for built-in convnet - "fcnet_activation", # Nonlinearity for fully connected net (tanh, relu) - "fcnet_hiddens", # Number of hidden layers for fully connected net - "dim", # Dimension for ATARI - "grayscale", # Converts ATARI frame to 1 Channel Grayscale image - "zero_mean", # Changes frame to range from [-1, 1] if true - "extra_frameskip", # (int) for number of frames to skip - "free_log_std", # Documented in ray.rllib.models.Model - "channel_major", # Pytorch conv requires images to be channel-major - "squash_to_range", # Whether to squash the action output to space range - "use_lstm", # Whether to wrap the model with a LSTM - "max_seq_len", # Max seq len for training the LSTM, defaults to 20 - "lstm_cell_size", # Size of the LSTM cell + "conv_filters": None, + # Nonlinearity for built-in convnet + "conv_activation": "relu", + # Nonlinearity for fully connected net (tanh, relu) + "fcnet_activation": "tanh", + # Number of hidden layers for fully connected net + "fcnet_hiddens": [256, 256], + # For control envs, documented in ray.rllib.models.Model + "free_log_std": False, + # (deprecated) Whether to use sigmoid to squash actions to space range + "squash_to_range": False, + + # == LSTM == + # Whether to wrap the model with a LSTM + "use_lstm": False, + # Max seq len for training the LSTM, defaults to 20 + "max_seq_len": 20, + # Size of the LSTM cell + "lstm_cell_size": 256, + # Whether to feed a_{t-1}, r_{t-1} to LSTM + "lstm_use_prev_action_reward": False, + + # == Atari == + # Whether to enable framestack for Atari envs + "framestack": True, + # Final resized frame dimension + "dim": 84, + # Pytorch conv requires images to be channel-major + "channel_major": False, + # (deprecated) Converts ATARI frame to 1 Channel Grayscale image + "grayscale": False, + # (deprecated) Changes frame to range from [-1, 1] if true + "zero_mean": True, # === Options for custom models === - "custom_preprocessor", # Name of a custom preprocessor to use - "custom_model", # Name of a custom model to use - "custom_options", # Extra options to pass to the custom classes -] + # Name of a custom preprocessor to use + "custom_preprocessor": None, + # Name of a custom model to use + "custom_model": None, + # Extra options to pass to the custom classes + "custom_options": {}, +} +# __sphinx_doc_end__ +# yapf: enable class ModelCatalog(object): @@ -51,14 +81,15 @@ class ModelCatalog(object): >>> prep = ModelCatalog.get_preprocessor(env) >>> observation = prep.transform(raw_observation) - >>> dist_cls, dist_dim = ModelCatalog.get_action_dist(env.action_space) - >>> model = ModelCatalog.get_model(inputs, dist_dim) + >>> dist_cls, dist_dim = ModelCatalog.get_action_dist( + env.action_space, {}) + >>> model = ModelCatalog.get_model(inputs, dist_dim, options) >>> dist = dist_cls(model.outputs) >>> action = dist.sample() """ @staticmethod - def get_action_dist(action_space, config=None, dist_type=None): + def get_action_dist(action_space, config, dist_type=None): """Returns action distribution class and size for the given action space. Args: @@ -71,18 +102,22 @@ def get_action_dist(action_space, config=None, dist_type=None): dist_dim (int): The size of the input vector to the distribution. """ - # TODO(ekl) are list spaces valid? 
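The MODEL_DEFAULTS keys above are overridden per experiment through the "model" section of an agent config, with unspecified keys falling back to the defaults. A minimal sketch with hypothetical settings:

config = {
    "model": {
        "use_lstm": True,
        "lstm_cell_size": 128,
        "lstm_use_prev_action_reward": True,
        "fcnet_hiddens": [64, 64],
    },
}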
- if isinstance(action_space, list): - action_space = gym.spaces.Tuple(action_space) - config = config or {} + config = config or MODEL_DEFAULTS if isinstance(action_space, gym.spaces.Box): + if len(action_space.shape) > 1: + raise ValueError( + "Action space has multiple dimensions " + "{}. ".format(action_space.shape) + + "Consider reshaping this into a single dimension, " + "using a Tuple action space, or the multi-agent API.") if dist_type is None: dist = DiagGaussian if config.get("squash_to_range"): - dist = squash_to_range(dist, action_space.low, - action_space.high) + raise ValueError( + "The squash_to_range option is deprecated. See the " + "clip_actions agent option instead.") return dist, action_space.shape[0] * 2 - elif dist_type == 'deterministic': + elif dist_type == "deterministic": return Deterministic, action_space.shape[0] elif isinstance(action_space, gym.spaces.Discrete): return Categorical, action_space.n @@ -90,7 +125,8 @@ def get_action_dist(action_space, config=None, dist_type=None): child_dist = [] input_lens = [] for action in action_space.spaces: - dist, action_size = ModelCatalog.get_action_dist(action) + dist, action_size = ModelCatalog.get_action_dist( + action, config) child_dist.append(dist) input_lens.append(action_size) return partial( @@ -112,10 +148,6 @@ def get_action_placeholder(action_space): action_placeholder (Tensor): A placeholder for the actions """ - # TODO(ekl) are list spaces valid? - if isinstance(action_space, list): - action_space = gym.spaces.Tuple(action_space) - if isinstance(action_space, gym.spaces.Box): return tf.placeholder( tf.float32, shape=(None, action_space.shape[0]), name="action") @@ -139,62 +171,71 @@ def get_action_placeholder(action_space): " not supported".format(action_space)) @staticmethod - def get_model(inputs, + def get_model(input_dict, + obs_space, num_outputs, - options=None, + options, state_in=None, seq_lens=None): """Returns a suitable model conforming to given input and output specs. Args: - inputs (Tensor): The input tensor to the model. + input_dict (dict): Dict of input tensors to the model, including + the observation under the "obs" key. + obs_space (Space): Observation space of the target gym env. num_outputs (int): The size of the output vector of the model. options (dict): Optional args to pass to the model constructor. state_in (list): Optional RNN state in tensors. seq_in (Tensor): Optional RNN sequence length tensor. Returns: - model (Model): Neural network model. + model (models.Model): Neural network model. 
""" - options = options or {} - model = ModelCatalog._get_model(inputs, num_outputs, options, state_in, - seq_lens) + assert isinstance(input_dict, dict) + options = options or MODEL_DEFAULTS + model = ModelCatalog._get_model(input_dict, obs_space, num_outputs, + options, state_in, seq_lens) if options.get("use_lstm"): - model = LSTM(model.last_layer, num_outputs, options, state_in, + copy = dict(input_dict) + copy["obs"] = model.last_layer + feature_space = gym.spaces.Box( + -1, 1, shape=(model.last_layer.shape[1], )) + model = LSTM(copy, feature_space, num_outputs, options, state_in, seq_lens) + logger.debug("Created model {}: ({} of {}, {}, {}) -> {}, {}".format( + model, input_dict, obs_space, state_in, seq_lens, model.outputs, + model.state_out)) + + model._validate_output_shape() return model @staticmethod - def _get_model(inputs, num_outputs, options, state_in, seq_lens): - if "custom_model" in options: + def _get_model(input_dict, obs_space, num_outputs, options, state_in, + seq_lens): + if options.get("custom_model"): model = options["custom_model"] - print("Using custom model {}".format(model)) + logger.debug("Using custom model {}".format(model)) return _global_registry.get(RLLIB_MODEL, model)( - inputs, + input_dict, + obs_space, num_outputs, options, state_in=state_in, seq_lens=seq_lens) - obs_rank = len(inputs.shape) - 1 - - # num_outputs > 1 used to avoid hitting this with the value function - if isinstance( - options.get("custom_options", {}).get( - "multiagent_fcnet_hiddens", 1), list) and num_outputs > 1: - return MultiAgentFullyConnectedNetwork(inputs, num_outputs, - options) + obs_rank = len(input_dict["obs"].shape) - 1 if obs_rank > 1: - return VisionNetwork(inputs, num_outputs, options) + return VisionNetwork(input_dict, obs_space, num_outputs, options) - return FullyConnectedNetwork(inputs, num_outputs, options) + return FullyConnectedNetwork(input_dict, obs_space, num_outputs, + options) @staticmethod - def get_torch_model(input_shape, num_outputs, options={}): + def get_torch_model(input_shape, num_outputs, options=None): """Returns a PyTorch suitable model. This is currently only supported in A3C. @@ -204,16 +245,17 @@ def get_torch_model(input_shape, num_outputs, options={}): options (dict): Optional args to pass to the model constructor. Returns: - model (Model): Neural network model. + model (models.Model): Neural network model. """ from ray.rllib.models.pytorch.fcnet import (FullyConnectedNetwork as PyTorchFCNet) from ray.rllib.models.pytorch.visionnet import (VisionNetwork as PyTorchVisionNet) - if "custom_model" in options: + options = options or MODEL_DEFAULTS + if options.get("custom_model"): model = options["custom_model"] - print("Using custom torch model {}".format(model)) + logger.info("Using custom torch model {}".format(model)) return _global_registry.get(RLLIB_MODEL, model)( input_shape, num_outputs, options) @@ -228,44 +270,68 @@ def get_torch_model(input_shape, num_outputs, options={}): return PyTorchFCNet(input_shape[0], num_outputs, options) @staticmethod - def get_preprocessor(env, options={}): - """Returns a suitable processor for the given environment. + def get_preprocessor(env, options=None): + """Returns a suitable preprocessor for the given env. + + This is a wrapper for get_preprocessor_for_space(). 
+ """ + + return ModelCatalog.get_preprocessor_for_space(env.observation_space, + options) + + @staticmethod + def get_preprocessor_for_space(observation_space, options=None): + """Returns a suitable preprocessor for the given observation space. Args: - env (gym.Env): The gym environment to preprocess. + observation_space (Space): The input observation space. options (dict): Options to pass to the preprocessor. Returns: - preprocessor (Preprocessor): Preprocessor for the env observations. + preprocessor (Preprocessor): Preprocessor for the observations. """ + + options = options or MODEL_DEFAULTS for k in options.keys(): - if k not in MODEL_CONFIGS: + if k not in MODEL_DEFAULTS: raise Exception("Unknown config key `{}`, all keys: {}".format( - k, MODEL_CONFIGS)) + k, list(MODEL_DEFAULTS))) - if "custom_preprocessor" in options: + if options.get("custom_preprocessor"): preprocessor = options["custom_preprocessor"] - print("Using custom preprocessor {}".format(preprocessor)) - return _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)( - env.observation_space, options) + logger.info("Using custom preprocessor {}".format(preprocessor)) + prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)( + observation_space, options) + else: + cls = get_preprocessor(observation_space) + prep = cls(observation_space, options) - preprocessor = get_preprocessor(env.observation_space) - return preprocessor(env.observation_space, options) + logger.debug("Created preprocessor {}: {} -> {}".format( + prep, observation_space, prep.shape)) + return prep @staticmethod - def get_preprocessor_as_wrapper(env, options={}): + def get_preprocessor_as_wrapper(env, options=None): """Returns a preprocessor as a gym observation wrapper. Args: - env (gym.Env): The gym environment to wrap. + env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap. options (dict): Options to pass to the preprocessor. Returns: - wrapper (gym.ObservationWrapper): Preprocessor in wrapper form. 
+ env (RLlib env): Wrapped environment """ + options = options or MODEL_DEFAULTS preprocessor = ModelCatalog.get_preprocessor(env, options) - return _RLlibPreprocessorWrapper(env, preprocessor) + if isinstance(env, gym.Env): + return _RLlibPreprocessorWrapper(env, preprocessor) + elif isinstance(env, VectorEnv): + return _RLlibVectorPreprocessorWrapper(env, preprocessor) + elif isinstance(env, ExternalEnv): + return _ExternalEnvToAsync(env, preprocessor) + else: + raise ValueError("Don't know how to wrap {}".format(env)) @staticmethod def register_custom_preprocessor(preprocessor_name, preprocessor_class): @@ -301,10 +367,32 @@ class _RLlibPreprocessorWrapper(gym.ObservationWrapper): def __init__(self, env, preprocessor): super(_RLlibPreprocessorWrapper, self).__init__(env) self.preprocessor = preprocessor - - from gym.spaces.box import Box - self.observation_space = Box( - -1.0, 1.0, preprocessor.shape, dtype=np.float32) + self.observation_space = preprocessor.observation_space def observation(self, observation): return self.preprocessor.transform(observation) + + +class _RLlibVectorPreprocessorWrapper(VectorEnv): + """Preprocessing wrapper for vector envs.""" + + def __init__(self, env, preprocessor): + self.env = env + self.prep = preprocessor + self.action_space = env.action_space + self.observation_space = preprocessor.observation_space + self.num_envs = env.num_envs + + def vector_reset(self): + return [self.prep.transform(obs) for obs in self.env.vector_reset()] + + def reset_at(self, index): + return self.prep.transform(self.env.reset_at(index)) + + def vector_step(self, actions): + obs, rewards, dones, infos = self.env.vector_step(actions) + obs = [self.prep.transform(o) for o in obs] + return obs, rewards, dones, infos + + def get_unwrapped(self): + return self.env.get_unwrapped() diff --git a/python/ray/rllib/models/fcnet.py b/python/ray/rllib/models/fcnet.py index 11aee2c0da8f4..19745b9e7a3ca 100644 --- a/python/ray/rllib/models/fcnet.py +++ b/python/ray/rllib/models/fcnet.py @@ -7,14 +7,22 @@ from ray.rllib.models.model import Model from ray.rllib.models.misc import normc_initializer, get_activation_fn +from ray.rllib.utils.annotations import override class FullyConnectedNetwork(Model): """Generic fully connected network.""" + @override(Model) def _build_layers(self, inputs, num_outputs, options): - hiddens = options.get("fcnet_hiddens", [256, 256]) - activation = get_activation_fn(options.get("fcnet_activation", "tanh")) + """Process the flattened inputs. + + Note that dict inputs will be flattened into a vector. To define a + model that processes the components separately, use _build_layers_v2(). + """ + + hiddens = options.get("fcnet_hiddens") + activation = get_activation_fn(options.get("fcnet_activation")) with tf.name_scope("fc_net"): i = 1 diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index 581569f0eff0c..323c7f375bb3f 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -9,6 +9,10 @@ postprocessing, we dynamically pad the experience batches so that this reshaping is possible. +Note that this padding strategy only works out if we assume zero inputs don't +meaningfully affect the loss function. This happens to be true for all the +current algorithms: https://github.com/ray-project/ray/issues/2992 + See the add_time_dimension() and chop_into_sequences() functions below for more info. 
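To illustrate the padding scheme, a small worked sketch of chop_into_sequences() (defined later in this file) with hypothetical inputs:

import numpy as np

from ray.rllib.models.lstm import chop_into_sequences

obs = np.arange(5).reshape(5, 1)    # five steps of a 1-d feature column
eps_ids = [1, 1, 1, 5, 5]           # steps from two different episodes
agent_ids = [0, 0, 0, 0, 0]         # a single agent throughout
f_pad, s_init, seq_lens = chop_into_sequences(
    eps_ids, agent_ids, [obs], [], max_seq_len=4)
print(seq_lens)  # [3 2]
# dynamic_max shrinks the padded length to the longest sequence (3), so
# f_pad[0] holds obs[0:3], then obs[3:5] followed by one zero row.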
""" @@ -19,6 +23,72 @@ from ray.rllib.models.misc import linear, normc_initializer from ray.rllib.models.model import Model +from ray.rllib.utils.annotations import override + + +class LSTM(Model): + """Adds a LSTM cell on top of some other model output. + + Uses a linear layer at the end for output. + + Important: we assume inputs is a padded batch of sequences denoted by + self.seq_lens. See add_time_dimension() for more information. + """ + + @override(Model) + def _build_layers_v2(self, input_dict, num_outputs, options): + cell_size = options.get("lstm_cell_size") + if options.get("lstm_use_prev_action_reward"): + action_dim = int( + np.product( + input_dict["prev_actions"].get_shape().as_list()[1:])) + features = tf.concat( + [ + input_dict["obs"], + tf.reshape( + tf.cast(input_dict["prev_actions"], tf.float32), + [-1, action_dim]), + tf.reshape(input_dict["prev_rewards"], [-1, 1]), + ], + axis=1) + else: + features = input_dict["obs"] + last_layer = add_time_dimension(features, self.seq_lens) + + # Setup the LSTM cell + lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True) + self.state_init = [ + np.zeros(lstm.state_size.c, np.float32), + np.zeros(lstm.state_size.h, np.float32) + ] + + # Setup LSTM inputs + if self.state_in: + c_in, h_in = self.state_in + else: + c_in = tf.placeholder( + tf.float32, [None, lstm.state_size.c], name="c") + h_in = tf.placeholder( + tf.float32, [None, lstm.state_size.h], name="h") + self.state_in = [c_in, h_in] + + # Setup LSTM outputs + state_in = rnn.LSTMStateTuple(c_in, h_in) + lstm_out, lstm_state = tf.nn.dynamic_rnn( + lstm, + last_layer, + initial_state=state_in, + sequence_length=self.seq_lens, + time_major=False, + dtype=tf.float32) + + self.state_out = list(lstm_state) + + # Compute outputs + last_layer = tf.reshape(lstm_out, [-1, cell_size]) + logits = linear(last_layer, num_outputs, "action", + normc_initializer(0.01)) + return logits, last_layer def add_time_dimension(padded_inputs, seq_lens): @@ -48,15 +118,24 @@ def add_time_dimension(padded_inputs, seq_lens): return tf.reshape(padded_inputs, new_shape) -def chop_into_sequences(episode_ids, feature_columns, state_columns, - max_seq_len): +def chop_into_sequences(episode_ids, + agent_indices, + feature_columns, + state_columns, + max_seq_len, + dynamic_max=True): """Truncate and pad experiences into fixed-length sequences. Arguments: episode_ids (list): List of episode ids for each step. + agent_indices (list): List of agent ids for each step. Note that this + has to be combined with episode_ids for uniqueness. feature_columns (list): List of arrays containing features. state_columns (list): List of arrays containing LSTM state values. max_seq_len (int): Max length of sequences before truncation. + dynamic_max (bool): Whether to dynamically shrink the max seq len. + For example, if max len is 20 and the actual max seq len in the + data is 7, it will be shrunk to 7. Returns: f_pad (list): Padded feature columns. 
These will be of shape @@ -84,19 +163,21 @@ def chop_into_sequences(episode_ids, feature_columns, state_columns, prev_id = None seq_lens = [] seq_len = 0 - for eps_id in episode_ids: - if (prev_id is not None and eps_id != prev_id) or \ + unique_ids = np.add(episode_ids, agent_indices) + for uid in unique_ids: + if (prev_id is not None and uid != prev_id) or \ seq_len >= max_seq_len: seq_lens.append(seq_len) seq_len = 0 seq_len += 1 - prev_id = eps_id + prev_id = uid if seq_len: seq_lens.append(seq_len) - assert sum(seq_lens) == len(episode_ids) + assert sum(seq_lens) == len(unique_ids) # Dynamically shrink max len as needed to optimize memory usage - max_seq_len = max(seq_lens) + if dynamic_max: + max_seq_len = max(seq_lens) feature_sequences = [] for f in feature_columns: @@ -109,7 +190,7 @@ def chop_into_sequences(episode_ids, feature_columns, state_columns, f_pad[seq_base + seq_offset] = f[i] i += 1 seq_base += max_seq_len - assert i == len(episode_ids), f + assert i == len(unique_ids), f feature_sequences.append(f_pad) initial_states = [] @@ -123,52 +204,3 @@ def chop_into_sequences(episode_ids, feature_columns, state_columns, initial_states.append(np.array(s_init)) return feature_sequences, initial_states, np.array(seq_lens) - - -class LSTM(Model): - """Adds a LSTM cell on top of some other model output. - - Uses a linear layer at the end for output. - - Important: we assume inputs is a padded batch of sequences denoted by - self.seq_lens. See add_time_dimension() for more information. - """ - - def _build_layers(self, inputs, num_outputs, options): - cell_size = options.get("lstm_cell_size", 256) - last_layer = add_time_dimension(inputs, self.seq_lens) - - # Setup the LSTM cell - lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True) - self.state_init = [ - np.zeros(lstm.state_size.c, np.float32), - np.zeros(lstm.state_size.h, np.float32) - ] - - # Setup LSTM inputs - if self.state_in: - c_in, h_in = self.state_in - else: - c_in = tf.placeholder( - tf.float32, [None, lstm.state_size.c], name="c") - h_in = tf.placeholder( - tf.float32, [None, lstm.state_size.h], name="h") - self.state_in = [c_in, h_in] - - # Setup LSTM outputs - state_in = rnn.LSTMStateTuple(c_in, h_in) - lstm_out, lstm_state = tf.nn.dynamic_rnn( - lstm, - last_layer, - initial_state=state_in, - sequence_length=self.seq_lens, - time_major=False, - dtype=tf.float32) - - self.state_out = list(lstm_state) - - # Compute outputs - last_layer = tf.reshape(lstm_out, [-1, cell_size]) - logits = linear(last_layer, num_outputs, "action", - normc_initializer(0.01)) - return logits, last_layer diff --git a/python/ray/rllib/models/model.py b/python/ray/rllib/models/model.py index 00d6575e62104..818966bb12e1a 100644 --- a/python/ray/rllib/models/model.py +++ b/python/ray/rllib/models/model.py @@ -2,8 +2,14 @@ from __future__ import division from __future__ import print_function +from collections import OrderedDict + +import gym import tensorflow as tf +from ray.rllib.models.misc import linear, normc_initializer +from ray.rllib.models.preprocessors import get_preprocessor + class Model(object): """Defines an abstract network model for use with RLlib. @@ -16,12 +22,12 @@ class Model(object): needs to further post-processing (e.g. Actor and Critic networks in A3C). Attributes: - inputs (Tensor): The input placeholder for this model, of shape - [BATCH_SIZE, ...]. + input_dict (dict): Dictionary of input tensors, including "obs", + "prev_action", "prev_reward", "is_training". 
outputs (Tensor): The output vector of this model, of shape [BATCH_SIZE, num_outputs]. - last_layer (Tensor): The network layer right before the model output, - of shape [BATCH_SIZE, N]. + last_layer (Tensor): The feature layer right before the model output, + of shape [BATCH_SIZE, f]. state_init (list): List of initial recurrent state tensors (if any). state_in (list): List of input recurrent state tensors (if any). state_out (list): List of output recurrent state tensors (if any). @@ -38,12 +44,13 @@ class Model(object): """ def __init__(self, - inputs, + input_dict, + obs_space, num_outputs, options, state_in=None, seq_lens=None): - self.inputs = inputs + assert isinstance(input_dict, dict), input_dict # Default attribute values for the non-RNN case self.state_init = [] @@ -55,11 +62,18 @@ def __init__(self, self.seq_lens = tf.placeholder( dtype=tf.int32, shape=[None], name="seq_lens") - if options.get("free_log_std", False): + self._num_outputs = num_outputs + if options.get("free_log_std"): assert num_outputs % 2 == 0 num_outputs = num_outputs // 2 - self.outputs, self.last_layer = self._build_layers( - inputs, num_outputs, options) + try: + self.outputs, self.last_layer = self._build_layers_v2( + _restore_original_dimensions(input_dict, obs_space), + num_outputs, options) + except NotImplementedError: + self.outputs, self.last_layer = self._build_layers( + input_dict["obs"], num_outputs, options) + if options.get("free_log_std", False): log_std = tf.get_variable( name="log_std", @@ -68,6 +82,118 @@ def __init__(self, self.outputs = tf.concat( [self.outputs, 0.0 * self.outputs + log_std], 1) - def _build_layers(self): - """Builds and returns the output and last layer of the network.""" + def _build_layers(self, inputs, num_outputs, options): + """Builds and returns the output and last layer of the network. + + Deprecated: use _build_layers_v2 instead, which has better support + for dict and tuple spaces. + """ + raise NotImplementedError + + def _build_layers_v2(self, input_dict, num_outputs, options): + """Define the layers of a custom model. + + Arguments: + input_dict (dict): Dictionary of input tensors, including "obs", + "prev_action", "prev_reward", "is_training". + num_outputs (int): Output tensor must be of size + [BATCH_SIZE, num_outputs]. + options (dict): Model options. + + Returns: + (outputs, feature_layer): Tensors of size [BATCH_SIZE, num_outputs] + and [BATCH_SIZE, desired_feature_size]. + + When using dict or tuple observation spaces, you can access + the nested sub-observation batches here as well: + + Examples: + >>> print(input_dict) + {'prev_actions': , + 'prev_rewards': , + 'is_training': , + 'obs': OrderedDict([ + ('sensors', OrderedDict([ + ('front_cam', [ + , + ]), + ('position', ), + ('velocity', )]))])} + """ raise NotImplementedError + + def value_function(self): + """Builds the value function output. + + This method can be overridden to customize the implementation of the + value function (e.g., not sharing hidden layers). + + Returns: + Tensor of size [BATCH_SIZE] for the value function. + """ + return tf.reshape( + linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1]) + + def loss(self): + """Builds any built-in (self-supervised) loss for the model. + + For example, this can be used to incorporate auto-encoder style losses. + Note that this loss has to be included in the policy graph loss to have + an effect (done for built-in algorithms). + + Returns: + Scalar tensor for the self-supervised loss. 
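To ground the _build_layers_v2() interface, a hypothetical custom model for a Dict observation space; the sub-observation names ("front_cam", "position") and layer sizes are illustrative assumptions:

import tensorflow as tf

from ray.rllib.models import Model


class MyDictObsModel(Model):
    def _build_layers_v2(self, input_dict, num_outputs, options):
        # Dict sub-observations arrive already unpacked into nested tensors.
        cam = input_dict["obs"]["front_cam"]
        pos = input_dict["obs"]["position"]
        features = tf.concat([tf.layers.flatten(cam), pos], axis=1)
        last_layer = tf.layers.dense(features, 64, activation=tf.nn.relu)
        output = tf.layers.dense(last_layer, num_outputs, activation=None)
        return output, last_layer

Such a model would be registered with ModelCatalog.register_custom_model() and selected via the "custom_model" option, as in the multiagent_cartpole.py example above.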
+ """ + return tf.constant(0.0) + + def _validate_output_shape(self): + """Checks that the model has the correct number of outputs.""" + try: + out = tf.convert_to_tensor(self.outputs) + shape = out.shape.as_list() + except Exception: + raise ValueError("Output is not a tensor: {}".format(self.outputs)) + else: + if len(shape) != 2 or shape[1] != self._num_outputs: + raise ValueError( + "Expected output shape of [None, {}], got {}".format( + self._num_outputs, shape)) + + +def _restore_original_dimensions(input_dict, obs_space): + if hasattr(obs_space, "original_space"): + return dict( + input_dict, + obs=_unpack_obs(input_dict["obs"], obs_space.original_space)) + return input_dict + + +def _unpack_obs(obs, space): + if (isinstance(space, gym.spaces.Dict) + or isinstance(space, gym.spaces.Tuple)): + prep = get_preprocessor(space)(space) + if len(obs.shape) != 2 or obs.shape[1] != prep.shape[0]: + raise ValueError( + "Expected flattened obs shape of [None, {}], got {}".format( + prep.shape[0], obs.shape)) + assert len(prep.preprocessors) == len(space.spaces), \ + (len(prep.preprocessors) == len(space.spaces)) + offset = 0 + if isinstance(space, gym.spaces.Tuple): + u = [] + for p, v in zip(prep.preprocessors, space.spaces): + obs_slice = obs[:, offset:offset + p.size] + offset += p.size + u.append( + _unpack_obs( + tf.reshape(obs_slice, [-1] + list(p.shape)), v)) + else: + u = OrderedDict() + for p, (k, v) in zip(prep.preprocessors, space.spaces.items()): + obs_slice = obs[:, offset:offset + p.size] + offset += p.size + u[k] = _unpack_obs( + tf.reshape(obs_slice, [-1] + list(p.shape)), v) + return u + else: + return obs diff --git a/python/ray/rllib/models/multiagentfcnet.py b/python/ray/rllib/models/multiagentfcnet.py deleted file mode 100644 index dad7f29831035..0000000000000 --- a/python/ray/rllib/models/multiagentfcnet.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -from ray.rllib.models.model import Model -from ray.rllib.models.fcnet import FullyConnectedNetwork -from ray.rllib.utils.reshaper import Reshaper - - -class MultiAgentFullyConnectedNetwork(Model): - """Multiagent fully connected network.""" - - def _build_layers(self, inputs, num_outputs, options): - # Split the input and output tensors - input_shapes = options["custom_options"]["multiagent_obs_shapes"] - output_shapes = options["custom_options"]["multiagent_act_shapes"] - input_reshaper = Reshaper(input_shapes) - output_reshaper = Reshaper(output_shapes) - split_inputs = input_reshaper.split_tensor(inputs) - num_actions = output_reshaper.split_number(num_outputs) - - custom_options = options["custom_options"] - hiddens = custom_options.get("multiagent_fcnet_hiddens", - [[256, 256]] * 1) - - # check for a shared model - shared_model = custom_options.get("multiagent_shared_model", 0) - reuse = tf.AUTO_REUSE if shared_model else False - outputs = [] - for i in range(len(hiddens)): - scope = "multi" if shared_model else "multi{}".format(i) - with tf.variable_scope(scope, reuse=reuse): - sub_options = options.copy() - sub_options.update({"fcnet_hiddens": hiddens[i]}) - # TODO(ev) make this support arbitrary networks - fcnet = FullyConnectedNetwork(split_inputs[i], - int(num_actions[i]), sub_options) - output = fcnet.outputs - outputs.append(output) - overall_output = tf.concat(outputs, axis=1) - return overall_output, outputs diff --git a/python/ray/rllib/models/preprocessors.py 
b/python/ray/rllib/models/preprocessors.py index c400dd9805d34..0238ef2d8d889 100644 --- a/python/ray/rllib/models/preprocessors.py +++ b/python/ray/rllib/models/preprocessors.py @@ -1,13 +1,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + +from collections import OrderedDict import cv2 +import logging import numpy as np import gym +from ray.rllib.utils.annotations import override + ATARI_OBS_SHAPE = (210, 160, 3) ATARI_RAM_OBS_SHAPE = (128, ) +logger = logging.getLogger(__name__) + class Preprocessor(object): """Defines an abstract observation preprocessor function. @@ -16,35 +23,59 @@ class Preprocessor(object): shape (obj): Shape of the preprocessed output. """ - def __init__(self, obs_space, options): + def __init__(self, obs_space, options=None): legacy_patch_shapes(obs_space) self._obs_space = obs_space - self._options = options - self._init() + self._options = options or {} + self.shape = self._init_shape(obs_space, options) - def _init(self): - pass + def _init_shape(self, obs_space, options): + """Returns the shape after preprocessing.""" + raise NotImplementedError def transform(self, observation): """Returns the preprocessed observation.""" raise NotImplementedError + @property + def size(self): + return int(np.product(self.shape)) + + @property + def observation_space(self): + obs_space = gym.spaces.Box(-1.0, 1.0, self.shape, dtype=np.float32) + # Stash the unwrapped space so that we can unwrap dict and tuple spaces + # automatically in model.py + if (isinstance(self, TupleFlatteningPreprocessor) + or isinstance(self, DictFlatteningPreprocessor)): + obs_space.original_space = self._obs_space + return obs_space + -class AtariPixelPreprocessor(Preprocessor): - def _init(self): - self._grayscale = self._options.get("grayscale", False) - self._zero_mean = self._options.get("zero_mean", True) - self._dim = self._options.get("dim", 84) - self._channel_major = self._options.get("channel_major", False) +class GenericPixelPreprocessor(Preprocessor): + """Generic image preprocessor. + + Note: for Atari games, use config {"preprocessor_pref": "deepmind"} + instead for deepmind-style Atari preprocessing. 
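As a companion sketch for the reworked Preprocessor interface (the shape is now returned from _init_shape() instead of being set as a side effect of _init()), a hypothetical custom preprocessor; the class name and scaling are illustrative:

import numpy as np

from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.utils.annotations import override


class ScaledFloatPreprocessor(Preprocessor):
    """Casts observations to float32 and rescales pixel values into [0, 1]."""

    @override(Preprocessor)
    def _init_shape(self, obs_space, options):
        return obs_space.shape

    @override(Preprocessor)
    def transform(self, observation):
        return np.asarray(observation, dtype=np.float32) / 255.0

Custom preprocessors are registered via ModelCatalog.register_custom_preprocessor() and selected with the "custom_preprocessor" model option.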
+ """ + + @override(Preprocessor) + def _init_shape(self, obs_space, options): + self._grayscale = options.get("grayscale") + self._zero_mean = options.get("zero_mean") + self._dim = options.get("dim") + self._channel_major = options.get("channel_major") if self._grayscale: - self.shape = (self._dim, self._dim, 1) + shape = (self._dim, self._dim, 1) else: - self.shape = (self._dim, self._dim, 3) + shape = (self._dim, self._dim, 3) # channel_major requires (# in-channels, row dim, col dim) if self._channel_major: - self.shape = self.shape[-1:] + self.shape[:-1] + shape = shape[-1:] + shape[:-1] + return shape + @override(Preprocessor) def transform(self, observation): """Downsamples images from (210, 160, 3) by the configured factor.""" scaled = observation[25:-25, :, :] @@ -69,27 +100,36 @@ def transform(self, observation): class AtariRamPreprocessor(Preprocessor): - def _init(self): - self.shape = (128, ) + @override(Preprocessor) + def _init_shape(self, obs_space, options): + return (128, ) + @override(Preprocessor) def transform(self, observation): return (observation - 128) / 128 class OneHotPreprocessor(Preprocessor): - def _init(self): - self.shape = (self._obs_space.n, ) + @override(Preprocessor) + def _init_shape(self, obs_space, options): + return (self._obs_space.n, ) + @override(Preprocessor) def transform(self, observation): arr = np.zeros(self._obs_space.n) + if not self._obs_space.contains(observation): + raise ValueError("Observation outside expected value range", + self._obs_space, observation) arr[observation] = 1 return arr class NoPreprocessor(Preprocessor): - def _init(self): - self.shape = self._obs_space.shape + @override(Preprocessor) + def _init_shape(self, obs_space, options): + return self._obs_space.shape + @override(Preprocessor) def transform(self, observation): return observation @@ -97,30 +137,61 @@ def transform(self, observation): class TupleFlatteningPreprocessor(Preprocessor): """Preprocesses each tuple element, then flattens it all into a vector. - If desired, the vector output can be unpacked via tf.reshape() within a - custom model to handle each component separately. + RLlib models will unpack the flattened output before _build_layers_v2(). """ - def _init(self): + @override(Preprocessor) + def _init_shape(self, obs_space, options): assert isinstance(self._obs_space, gym.spaces.Tuple) size = 0 self.preprocessors = [] for i in range(len(self._obs_space.spaces)): space = self._obs_space.spaces[i] - print("Creating sub-preprocessor for", space) + logger.debug("Creating sub-preprocessor for {}".format(space)) preprocessor = get_preprocessor(space)(space, self._options) self.preprocessors.append(preprocessor) - size += np.product(preprocessor.shape) - self.shape = (size, ) + size += preprocessor.size + return (size, ) + @override(Preprocessor) def transform(self, observation): assert len(observation) == len(self.preprocessors), observation return np.concatenate([ - np.reshape(p.transform(o), [np.product(p.shape)]) + np.reshape(p.transform(o), [p.size]) for (o, p) in zip(observation, self.preprocessors) ]) +class DictFlatteningPreprocessor(Preprocessor): + """Preprocesses each dict value, then flattens it all into a vector. + + RLlib models will unpack the flattened output before _build_layers_v2(). 
+ """ + + @override(Preprocessor) + def _init_shape(self, obs_space, options): + assert isinstance(self._obs_space, gym.spaces.Dict) + size = 0 + self.preprocessors = [] + for space in self._obs_space.spaces.values(): + logger.debug("Creating sub-preprocessor for {}".format(space)) + preprocessor = get_preprocessor(space)(space, self._options) + self.preprocessors.append(preprocessor) + size += preprocessor.size + return (size, ) + + @override(Preprocessor) + def transform(self, observation): + if not isinstance(observation, OrderedDict): + observation = OrderedDict(sorted(list(observation.items()))) + assert len(observation) == len(self.preprocessors), \ + (len(observation), len(self.preprocessors)) + return np.concatenate([ + np.reshape(p.transform(o), [p.size]) + for (o, p) in zip(observation.values(), self.preprocessors) + ]) + + def get_preprocessor(space): """Returns an appropriate preprocessor class for the given space.""" @@ -130,11 +201,13 @@ def get_preprocessor(space): if isinstance(space, gym.spaces.Discrete): preprocessor = OneHotPreprocessor elif obs_shape == ATARI_OBS_SHAPE: - preprocessor = AtariPixelPreprocessor + preprocessor = GenericPixelPreprocessor elif obs_shape == ATARI_RAM_OBS_SHAPE: preprocessor = AtariRamPreprocessor elif isinstance(space, gym.spaces.Tuple): preprocessor = TupleFlatteningPreprocessor + elif isinstance(space, gym.spaces.Dict): + preprocessor = DictFlatteningPreprocessor else: preprocessor = NoPreprocessor diff --git a/python/ray/rllib/models/pytorch/fcnet.py b/python/ray/rllib/models/pytorch/fcnet.py index e8f50da2fb340..f69cb7ca21d45 100644 --- a/python/ray/rllib/models/pytorch/fcnet.py +++ b/python/ray/rllib/models/pytorch/fcnet.py @@ -2,10 +2,14 @@ from __future__ import division from __future__ import print_function +import logging + from ray.rllib.models.pytorch.model import Model, SlimFC from ray.rllib.models.pytorch.misc import normc_initializer import torch.nn as nn +logger = logging.getLogger(__name__) + class FullyConnectedNetwork(Model): """TODO(rliaw): Logits, Value should both be contained here""" @@ -19,7 +23,7 @@ def _build_layers(self, inputs, num_outputs, options): activation = nn.Tanh elif fcnet_activation == "relu": activation = nn.ReLU - print("Constructing fcnet {} {}".format(hiddens, activation)) + logger.info("Constructing fcnet {} {}".format(hiddens, activation)) layers = [] last_layer_size = inputs diff --git a/python/ray/rllib/models/pytorch/visionnet.py b/python/ray/rllib/models/pytorch/visionnet.py index 94ac8291d79af..e54c51897f2c3 100644 --- a/python/ray/rllib/models/pytorch/visionnet.py +++ b/python/ray/rllib/models/pytorch/visionnet.py @@ -18,11 +18,11 @@ def _build_layers(self, inputs, num_outputs, options): inputs (tuple): (channels, rows/height, cols/width) num_outputs (int): logits size """ - filters = options.get("conv_filters", [ + filters = options.get("conv_filters") or [ [16, [8, 8], 4], [32, [4, 4], 2], [512, [11, 11], 1], - ]) + ] layers = [] in_channels, in_size = inputs[0], inputs[1:] diff --git a/python/ray/rllib/models/visionnet.py b/python/ray/rllib/models/visionnet.py index 805d2e9e5ebef..0638c4fc83c59 100644 --- a/python/ray/rllib/models/visionnet.py +++ b/python/ray/rllib/models/visionnet.py @@ -7,17 +7,20 @@ from ray.rllib.models.model import Model from ray.rllib.models.misc import get_activation_fn, flatten +from ray.rllib.utils.annotations import override class VisionNetwork(Model): """Generic vision network.""" - def _build_layers(self, inputs, num_outputs, options): + @override(Model) + def 
_build_layers_v2(self, input_dict, num_outputs, options): + inputs = input_dict["obs"] filters = options.get("conv_filters") if not filters: - filters = get_filter_config(options) + filters = _get_filter_config(inputs) - activation = get_activation_fn(options.get("conv_activation", "relu")) + activation = get_activation_fn(options.get("conv_activation")) with tf.name_scope("vision_net"): for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1): @@ -46,7 +49,7 @@ def _build_layers(self, inputs, num_outputs, options): return flatten(fc2), flatten(fc1) -def get_filter_config(options): +def _get_filter_config(inputs): filters_84x84 = [ [16, [8, 8], 4], [32, [4, 4], 2], @@ -57,12 +60,15 @@ def get_filter_config(options): [32, [4, 4], 2], [256, [11, 11], 1], ] - dim = options.get("dim", 84) - if dim == 84: + shape = inputs.shape.as_list()[1:] + if len(shape) == 3 and shape[:2] == [84, 84]: return filters_84x84 - elif dim == 42: + elif len(shape) == 3 and shape[:2] == [42, 42]: return filters_42x42 else: raise ValueError( - "No default configuration for image size={}".format(dim) + - ", you must specify `conv_filters` manually as a model option.") + "No default configuration for obs input {}".format(inputs) + + ", you must specify `conv_filters` manually as a model option. " + "Default configurations are only available for inputs of size " + "[?, 42, 42, K] and [?, 84, 84, K]. You may alternatively want " + "to use a custom model or preprocessor.") diff --git a/python/ray/rllib/optimizers/async_gradients_optimizer.py b/python/ray/rllib/optimizers/async_gradients_optimizer.py index fc7fdb2488a33..b1e5ebe846ca6 100644 --- a/python/ray/rllib/optimizers/async_gradients_optimizer.py +++ b/python/ray/rllib/optimizers/async_gradients_optimizer.py @@ -4,6 +4,7 @@ import ray from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer +from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat @@ -15,6 +16,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer): gradient computations on the remote workers. 
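The step() rewrite just below drops the old FIFO gradient_queue (which could not use ray.wait, per the removed note) in favor of a dict of in-flight futures drained with ray.wait, so whichever evaluator finishes first is processed first. A tiny standalone sketch of that pattern, using a made-up remote task rather than RLlib's evaluators:

    import ray

    ray.init()

    @ray.remote
    def slow_square(x):
        return x * x

    # Key the dict by future so the producing worker can be recovered later.
    pending = {slow_square.remote(i): i for i in range(4)}
    while pending:
        ready, _ = ray.wait(list(pending.keys()), num_returns=1)
        future = ready[0]
        result = ray.get(future)
        worker_id = pending.pop(future)
        print("worker {} finished: {}".format(worker_id, result))
        # In AsyncGradientsOptimizer this is where the gradient is applied and
        # a fresh compute_gradients task is launched for that evaluator.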
""" + @override(PolicyOptimizer) def _init(self, grads_per_step=100): self.apply_timer = TimerStat() self.wait_timer = TimerStat() @@ -25,23 +27,29 @@ def _init(self, grads_per_step=100): raise ValueError( "Async optimizer requires at least 1 remote evaluator") + @override(PolicyOptimizer) def step(self): weights = ray.put(self.local_evaluator.get_weights()) - gradient_queue = [] + pending_gradients = {} num_gradients = 0 # Kick off the first wave of async tasks for e in self.remote_evaluators: e.set_weights.remote(weights) - fut = e.compute_gradients.remote(e.sample.remote()) - gradient_queue.append((fut, e)) + future = e.compute_gradients.remote(e.sample.remote()) + pending_gradients[future] = e num_gradients += 1 - # Note: can't use wait: https://github.com/ray-project/ray/issues/1128 - while gradient_queue: + while pending_gradients: with self.wait_timer: - fut, e = gradient_queue.pop(0) - gradient, info = ray.get(fut) + wait_results = ray.wait( + list(pending_gradients.keys()), num_returns=1) + ready_list = wait_results[0] + future = ready_list[0] + + gradient, info = ray.get(future) + e = pending_gradients.pop(future) + if "stats" in info: self.learner_stats = info["stats"] @@ -54,10 +62,12 @@ def step(self): if num_gradients < self.grads_per_step: with self.dispatch_timer: e.set_weights.remote(self.local_evaluator.get_weights()) - fut = e.compute_gradients.remote(e.sample.remote()) - gradient_queue.append((fut, e)) + future = e.compute_gradients.remote(e.sample.remote()) + + pending_gradients[future] = e num_gradients += 1 + @override(PolicyOptimizer) def stats(self): return dict( PolicyOptimizer.stats(self), **{ diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py index 3ed5f37d390fe..582bb65396c14 100644 --- a/python/ray/rllib/optimizers/async_replay_optimizer.py +++ b/python/ray/rllib/optimizers/async_replay_optimizer.py @@ -6,6 +6,7 @@ from __future__ import division from __future__ import print_function +import collections import os import random import time @@ -15,9 +16,11 @@ from six.moves import queue import ray +from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ + MultiAgentBatch from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override from ray.rllib.utils.actors import TaskPool, create_colocated from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat @@ -27,113 +30,6 @@ LEARNER_QUEUE_MAX_SIZE = 16 -@ray.remote(num_cpus=0) -class ReplayActor(object): - """A replay buffer shard. 
- - Ray actors are single-threaded, so for scalability multiple replay actors - may be created to increase parallelism.""" - - def __init__(self, num_shards, learning_starts, buffer_size, - train_batch_size, prioritized_replay_alpha, - prioritized_replay_beta, prioritized_replay_eps): - self.replay_starts = learning_starts // num_shards - self.buffer_size = buffer_size // num_shards - self.train_batch_size = train_batch_size - self.prioritized_replay_beta = prioritized_replay_beta - self.prioritized_replay_eps = prioritized_replay_eps - - self.replay_buffer = PrioritizedReplayBuffer( - self.buffer_size, alpha=prioritized_replay_alpha) - - # Metrics - self.add_batch_timer = TimerStat() - self.replay_timer = TimerStat() - self.update_priorities_timer = TimerStat() - - def get_host(self): - return os.uname()[1] - - def add_batch(self, batch): - PolicyOptimizer._check_not_multiagent(batch) - with self.add_batch_timer: - for row in batch.rows(): - self.replay_buffer.add(row["obs"], row["actions"], - row["rewards"], row["new_obs"], - row["dones"], row["weights"]) - - def replay(self): - with self.replay_timer: - if len(self.replay_buffer) < self.replay_starts: - return None - - (obses_t, actions, rewards, obses_tp1, dones, weights, - batch_indexes) = self.replay_buffer.sample( - self.train_batch_size, beta=self.prioritized_replay_beta) - - batch = SampleBatch({ - "obs": obses_t, - "actions": actions, - "rewards": rewards, - "new_obs": obses_tp1, - "dones": dones, - "weights": weights, - "batch_indexes": batch_indexes - }) - return batch - - def update_priorities(self, batch_indexes, td_errors): - with self.update_priorities_timer: - new_priorities = (np.abs(td_errors) + self.prioritized_replay_eps) - self.replay_buffer.update_priorities(batch_indexes, new_priorities) - - def stats(self): - stat = { - "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3), - "replay_time_ms": round(1000 * self.replay_timer.mean, 3), - "update_priorities_time_ms": round( - 1000 * self.update_priorities_timer.mean, 3), - } - stat.update(self.replay_buffer.stats()) - return stat - - -class LearnerThread(threading.Thread): - """Background thread that updates the local model from replay data. - - The learner thread communicates with the main thread through Queues. This - is needed since Ray operations can only be run on the main thread. In - addition, moving heavyweight gradient ops session runs off the main thread - improves overall throughput. - """ - - def __init__(self, local_evaluator): - threading.Thread.__init__(self) - self.learner_queue_size = WindowStat("size", 50) - self.local_evaluator = local_evaluator - self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE) - self.outqueue = queue.Queue() - self.queue_timer = TimerStat() - self.grad_timer = TimerStat() - self.daemon = True - self.weights_updated = False - - def run(self): - while True: - self.step() - - def step(self): - with self.queue_timer: - ra, replay = self.inqueue.get() - if replay is not None: - with self.grad_timer: - td_error = self.local_evaluator.compute_apply(replay)[ - "td_error"] - self.outqueue.put((ra, replay, td_error, replay.count)) - self.learner_queue_size.push(self.inqueue.qsize()) - self.weights_updated = True - - class AsyncReplayOptimizer(PolicyOptimizer): """Main event loop of the Ape-X optimizer (async sampling with replay). @@ -144,6 +40,7 @@ class AsyncReplayOptimizer(PolicyOptimizer): "td_error" array in the info return of compute_gradients(). 
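Concretely, the learner-to-replay handoff in the new code below is keyed by policy id: LearnerThread.step() emits a prio_dict of the form {policy_id: (batch_indexes, td_errors)}, and ReplayActor.update_priorities() turns the TD errors into new priorities for that policy's buffer. A small numpy sketch of just that conversion, with a placeholder policy id and made-up values:

    import numpy as np

    prioritized_replay_eps = 1e-6
    # Shape of what LearnerThread.step() puts on the outqueue (values made up).
    prio_dict = {
        "my_policy": (np.array([3, 7, 11]),         # batch_indexes
                      np.array([0.5, -2.0, 0.1])),  # td_errors
    }

    for policy_id, (batch_indexes, td_errors) in prio_dict.items():
        new_priorities = np.abs(td_errors) + prioritized_replay_eps
        # The real replay shard then calls:
        #   self.replay_buffers[policy_id].update_priorities(
        #       batch_indexes, new_priorities)
        print(policy_id, batch_indexes, new_priorities)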
This error term will be used for sample prioritization.""" + @override(PolicyOptimizer) def _init(self, learning_starts=1000, buffer_size=10000, @@ -180,11 +77,12 @@ def _init(self, self.timers = { k: TimerStat() for k in [ - "put_weights", "get_samples", "enqueue", "sample_processing", + "put_weights", "get_samples", "sample_processing", "replay_processing", "update_priorities", "train", "sample" ] } self.num_weight_syncs = 0 + self.num_samples_dropped = 0 self.learning_started = False # Number of worker steps since the last weight update @@ -199,18 +97,9 @@ def _init(self, # Kick off async background sampling self.sample_tasks = TaskPool() if self.remote_evaluators: - self.set_evaluators(self.remote_evaluators) - - # For https://github.com/ray-project/ray/issues/2541 only - def set_evaluators(self, remote_evaluators): - self.remote_evaluators = remote_evaluators - weights = self.local_evaluator.get_weights() - for ev in self.remote_evaluators: - ev.set_weights.remote(weights) - self.steps_since_update[ev] = 0 - for _ in range(SAMPLE_QUEUE_DEPTH): - self.sample_tasks.add(ev, ev.sample_with_count.remote()) + self._set_evaluators(self.remote_evaluators) + @override(PolicyOptimizer) def step(self): assert len(self.remote_evaluators) > 0 start = time.time() @@ -226,6 +115,53 @@ def step(self): self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps + @override(PolicyOptimizer) + def stop(self): + for r in self.replay_actors: + r.__ray_terminate__.remote() + self.learner.stopped = True + + @override(PolicyOptimizer) + def stats(self): + replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug)) + timing = { + "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) + for k in self.timers + } + timing["learner_grad_time_ms"] = round( + 1000 * self.learner.grad_timer.mean, 3) + timing["learner_dequeue_time_ms"] = round( + 1000 * self.learner.queue_timer.mean, 3) + stats = { + "sample_throughput": round(self.timers["sample"].mean_throughput, + 3), + "train_throughput": round(self.timers["train"].mean_throughput, 3), + "num_weight_syncs": self.num_weight_syncs, + "num_samples_dropped": self.num_samples_dropped, + "learner_queue": self.learner.learner_queue_size.stats(), + "replay_shard_0": replay_stats, + } + debug_stats = { + "timing_breakdown": timing, + "pending_sample_tasks": self.sample_tasks.count, + "pending_replay_tasks": self.replay_tasks.count, + } + if self.debug: + stats.update(debug_stats) + if self.learner.stats: + stats["learner"] = self.learner.stats + return dict(PolicyOptimizer.stats(self), **stats) + + # For https://github.com/ray-project/ray/issues/2541 only + def _set_evaluators(self, remote_evaluators): + self.remote_evaluators = remote_evaluators + weights = self.local_evaluator.get_weights() + for ev in self.remote_evaluators: + ev.set_weights.remote(weights) + self.steps_since_update[ev] = 0 + for _ in range(SAMPLE_QUEUE_DEPTH): + self.sample_tasks.add(ev, ev.sample_with_count.remote()) + def _step(self): sample_timesteps, train_timesteps = 0, 0 weights = None @@ -260,42 +196,148 @@ def _step(self): with self.timers["replay_processing"]: for ra, replay in self.replay_tasks.completed(): self.replay_tasks.add(ra, ra.replay.remote()) - with self.timers["get_samples"]: - samples = ray.get(replay) - with self.timers["enqueue"]: - self.learner.inqueue.put((ra, samples)) + if self.learner.inqueue.full(): + self.num_samples_dropped += 1 + else: + with self.timers["get_samples"]: + samples = ray.get(replay) + # Defensive copy against 
plasma crashes, see #2610 #3452 + self.learner.inqueue.put((ra, samples and samples.copy())) with self.timers["update_priorities"]: while not self.learner.outqueue.empty(): - ra, replay, td_error, count = self.learner.outqueue.get() - ra.update_priorities.remote(replay["batch_indexes"], td_error) + ra, prio_dict, count = self.learner.outqueue.get() + ra.update_priorities.remote(prio_dict) train_timesteps += count return sample_timesteps, train_timesteps - def stats(self): - replay_stats = ray.get(self.replay_actors[0].stats.remote()) - timing = { - "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) - for k in self.timers - } - timing["learner_grad_time_ms"] = round( - 1000 * self.learner.grad_timer.mean, 3) - timing["learner_dequeue_time_ms"] = round( - 1000 * self.learner.queue_timer.mean, 3) - stats = { - "sample_throughput": round(self.timers["sample"].mean_throughput, - 3), - "train_throughput": round(self.timers["train"].mean_throughput, 3), - "num_weight_syncs": self.num_weight_syncs, - } - debug_stats = { - "replay_shard_0": replay_stats, - "timing_breakdown": timing, - "pending_sample_tasks": self.sample_tasks.count, - "pending_replay_tasks": self.replay_tasks.count, - "learner_queue": self.learner.learner_queue_size.stats(), + +@ray.remote(num_cpus=0) +class ReplayActor(object): + """A replay buffer shard. + + Ray actors are single-threaded, so for scalability multiple replay actors + may be created to increase parallelism.""" + + def __init__(self, num_shards, learning_starts, buffer_size, + train_batch_size, prioritized_replay_alpha, + prioritized_replay_beta, prioritized_replay_eps): + self.replay_starts = learning_starts // num_shards + self.buffer_size = buffer_size // num_shards + self.train_batch_size = train_batch_size + self.prioritized_replay_beta = prioritized_replay_beta + self.prioritized_replay_eps = prioritized_replay_eps + + def new_buffer(): + return PrioritizedReplayBuffer( + self.buffer_size, alpha=prioritized_replay_alpha) + + self.replay_buffers = collections.defaultdict(new_buffer) + + # Metrics + self.add_batch_timer = TimerStat() + self.replay_timer = TimerStat() + self.update_priorities_timer = TimerStat() + self.num_added = 0 + + def get_host(self): + return os.uname()[1] + + def add_batch(self, batch): + # Handle everything as if multiagent + if isinstance(batch, SampleBatch): + batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count) + with self.add_batch_timer: + for policy_id, s in batch.policy_batches.items(): + for row in s.rows(): + self.replay_buffers[policy_id].add( + row["obs"], row["actions"], row["rewards"], + row["new_obs"], row["dones"], row["weights"]) + self.num_added += batch.count + + def replay(self): + if self.num_added < self.replay_starts: + return None + + with self.replay_timer: + samples = {} + for policy_id, replay_buffer in self.replay_buffers.items(): + (obses_t, actions, rewards, obses_tp1, dones, weights, + batch_indexes) = replay_buffer.sample( + self.train_batch_size, beta=self.prioritized_replay_beta) + samples[policy_id] = SampleBatch({ + "obs": obses_t, + "actions": actions, + "rewards": rewards, + "new_obs": obses_tp1, + "dones": dones, + "weights": weights, + "batch_indexes": batch_indexes + }) + return MultiAgentBatch(samples, self.train_batch_size) + + def update_priorities(self, prio_dict): + with self.update_priorities_timer: + for policy_id, (batch_indexes, td_errors) in prio_dict.items(): + new_priorities = ( + np.abs(td_errors) + self.prioritized_replay_eps) + 
self.replay_buffers[policy_id].update_priorities( + batch_indexes, new_priorities) + + def stats(self, debug=False): + stat = { + "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3), + "replay_time_ms": round(1000 * self.replay_timer.mean, 3), + "update_priorities_time_ms": round( + 1000 * self.update_priorities_timer.mean, 3), } - if self.debug: - stats.update(debug_stats) - return dict(PolicyOptimizer.stats(self), **stats) + for policy_id, replay_buffer in self.replay_buffers.items(): + stat.update({ + "policy_{}".format(policy_id): replay_buffer.stats(debug=debug) + }) + return stat + + +class LearnerThread(threading.Thread): + """Background thread that updates the local model from replay data. + + The learner thread communicates with the main thread through Queues. This + is needed since Ray operations can only be run on the main thread. In + addition, moving heavyweight gradient ops session runs off the main thread + improves overall throughput. + """ + + def __init__(self, local_evaluator): + threading.Thread.__init__(self) + self.learner_queue_size = WindowStat("size", 50) + self.local_evaluator = local_evaluator + self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE) + self.outqueue = queue.Queue() + self.queue_timer = TimerStat() + self.grad_timer = TimerStat() + self.daemon = True + self.weights_updated = False + self.stopped = False + self.stats = {} + + def run(self): + while not self.stopped: + self.step() + + def step(self): + with self.queue_timer: + ra, replay = self.inqueue.get() + if replay is not None: + prio_dict = {} + with self.grad_timer: + grad_out = self.local_evaluator.compute_apply(replay) + for pid, info in grad_out.items(): + prio_dict[pid] = ( + replay.policy_batches[pid]["batch_indexes"], + info["td_error"]) + if "stats" in info: + self.stats[pid] = info["stats"] + self.outqueue.put((ra, prio_dict, replay.count)) + self.learner_queue_size.push(self.inqueue.qsize()) + self.weights_updated = True diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py index 3b6bb861b4824..ad0d86dfc2c36 100644 --- a/python/ray/rllib/optimizers/async_samples_optimizer.py +++ b/python/ray/rllib/optimizers/async_samples_optimizer.py @@ -6,58 +6,26 @@ from __future__ import division from __future__ import print_function +import logging +import numpy as np +import random import time import threading from six.moves import queue import ray +from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.utils.actors import TaskPool +from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat -SAMPLE_QUEUE_DEPTH = 2 -LEARNER_QUEUE_MAX_SIZE = 16 - - -class LearnerThread(threading.Thread): - """Background thread that updates the local model from sample trajectories. - - The learner thread communicates with the main thread through Queues. This - is needed since Ray operations can only be run on the main thread. In - addition, moving heavyweight gradient ops session runs off the main thread - improves overall throughput. 
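Both learner threads in this patch follow the same shape: a daemon thread fed through a bounded inqueue, reporting processed timestep counts on an outqueue, so that all Ray calls stay on the main thread. A stripped-down sketch with a stand-in for compute_apply(), independent of RLlib:

    import threading
    from six.moves import queue

    class ToyLearner(threading.Thread):
        def __init__(self):
            threading.Thread.__init__(self)
            self.inqueue = queue.Queue(maxsize=16)
            self.outqueue = queue.Queue()
            self.daemon = True
            self.stopped = False

        def run(self):
            while not self.stopped:
                batch = self.inqueue.get()     # blocks until the driver enqueues
                count = len(batch)             # stand-in for compute_apply(batch)
                self.outqueue.put(count)

    learner = ToyLearner()
    learner.start()
    learner.inqueue.put([0.1, 0.2, 0.3])
    print(learner.outqueue.get())              # -> 3
    learner.stopped = True                     # same idea as optimizer.stop()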
- """ +logger = logging.getLogger(__name__) - def __init__(self, local_evaluator): - threading.Thread.__init__(self) - self.learner_queue_size = WindowStat("size", 50) - self.local_evaluator = local_evaluator - self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE) - self.outqueue = queue.Queue() - self.queue_timer = TimerStat() - self.grad_timer = TimerStat() - self.daemon = True - self.weights_updated = 0 - self.stats = {} - - def run(self): - while True: - self.step() - - def step(self): - with self.queue_timer: - ra, batch = self.inqueue.get() - - if batch is not None: - with self.grad_timer: - fetches = self.local_evaluator.compute_apply(batch) - self.weights_updated += 1 - if "stats" in fetches: - self.stats = fetches["stats"] - self.outqueue.put(batch.count) - self.learner_queue_size.push(self.inqueue.qsize()) +LEARNER_QUEUE_MAX_SIZE = 16 +NUM_DATA_LOAD_THREADS = 16 class AsyncSamplesOptimizer(PolicyOptimizer): @@ -67,24 +35,49 @@ class AsyncSamplesOptimizer(PolicyOptimizer): and remote evaluators (IMPALA actors). """ - def _init(self, train_batch_size=512, sample_batch_size=50, debug=False): - - self.debug = debug + @override(PolicyOptimizer) + def _init(self, + train_batch_size=500, + sample_batch_size=50, + num_envs_per_worker=1, + num_gpus=0, + lr=0.0005, + grad_clip=40, + replay_buffer_num_slots=0, + replay_proportion=0.0, + num_parallel_data_loaders=1, + max_sample_requests_in_flight_per_worker=2, + broadcast_interval=1): self.learning_started = False self.train_batch_size = train_batch_size + self.sample_batch_size = sample_batch_size + self.broadcast_interval = broadcast_interval - self.learner = LearnerThread(self.local_evaluator) + if num_gpus > 1 or num_parallel_data_loaders > 1: + logger.info( + "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format( + num_gpus, num_parallel_data_loaders)) + if train_batch_size // max(1, num_gpus) % ( + sample_batch_size // num_envs_per_worker) != 0: + raise ValueError( + "Sample batches must evenly divide across GPUs.") + self.learner = TFMultiGPULearner( + self.local_evaluator, + lr=lr, + num_gpus=num_gpus, + train_batch_size=train_batch_size, + grad_clip=grad_clip, + num_parallel_data_loaders=num_parallel_data_loaders) + else: + self.learner = LearnerThread(self.local_evaluator) self.learner.start() assert len(self.remote_evaluators) > 0 # Stats - self.timers = { - k: TimerStat() - for k in - ["put_weights", "enqueue", "sample_processing", "train", "sample"] - } + self.timers = {k: TimerStat() for k in ["train", "sample"]} self.num_weight_syncs = 0 + self.num_replayed = 0 self.learning_started = False # Kick off async background sampling @@ -92,11 +85,20 @@ def _init(self, train_batch_size=512, sample_batch_size=50, debug=False): weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) - for _ in range(SAMPLE_QUEUE_DEPTH): + for _ in range(max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) self.batch_buffer = [] + if replay_proportion: + assert replay_buffer_num_slots > 0 + assert (replay_buffer_num_slots * sample_batch_size > + train_batch_size) + self.replay_proportion = replay_proportion + self.replay_buffer_num_slots = replay_buffer_num_slots + self.replay_batches = [] + + @override(PolicyOptimizer) def step(self): assert self.learner.is_alive() start = time.time() @@ -112,41 +114,11 @@ def step(self): self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps - def _step(self): - sample_timesteps, 
train_timesteps = 0, 0 - weights = None - - with self.timers["sample_processing"]: - for ev, sample_batch in self.sample_tasks.completed_prefetch(): - sample_batch = ray.get(sample_batch) - sample_timesteps += sample_batch.count - self.batch_buffer.append(sample_batch) - if sum(b.count - for b in self.batch_buffer) >= self.train_batch_size: - train_batch = self.batch_buffer[0].concat_samples( - self.batch_buffer) - with self.timers["enqueue"]: - self.learner.inqueue.put((ev, train_batch)) - self.batch_buffer = [] - - # Note that it's important to pull new weights once - # updated to avoid excessive correlation between actors - if weights is None or self.learner.weights_updated: - self.learner.weights_updated = False - with self.timers["put_weights"]: - weights = ray.put(self.local_evaluator.get_weights()) - ev.set_weights.remote(weights) - self.num_weight_syncs += 1 - - # Kick off another sample request - self.sample_tasks.add(ev, ev.sample.remote()) - - while not self.learner.outqueue.empty(): - count = self.learner.outqueue.get() - train_timesteps += count - - return sample_timesteps, train_timesteps + @override(PolicyOptimizer) + def stop(self): + self.learner.stopped = True + @override(PolicyOptimizer) def stats(self): timing = { "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) @@ -154,6 +126,10 @@ def stats(self): } timing["learner_grad_time_ms"] = round( 1000 * self.learner.grad_timer.mean, 3) + timing["learner_load_time_ms"] = round( + 1000 * self.learner.load_timer.mean, 3) + timing["learner_load_wait_time_ms"] = round( + 1000 * self.learner.load_wait_timer.mean, 3) timing["learner_dequeue_time_ms"] = round( 1000 * self.learner.queue_timer.mean, 3) stats = { @@ -161,14 +137,229 @@ def stats(self): 3), "train_throughput": round(self.timers["train"].mean_throughput, 3), "num_weight_syncs": self.num_weight_syncs, - } - debug_stats = { + "num_steps_replayed": self.num_replayed, "timing_breakdown": timing, - "pending_sample_tasks": self.sample_tasks.count, "learner_queue": self.learner.learner_queue_size.stats(), } - if self.debug: - stats.update(debug_stats) if self.learner.stats: stats["learner"] = self.learner.stats return dict(PolicyOptimizer.stats(self), **stats) + + def _step(self): + sample_timesteps, train_timesteps = 0, 0 + num_sent = 0 + weights = None + + for ev, sample_batch in self._augment_with_replay( + self.sample_tasks.completed_prefetch()): + self.batch_buffer.append(sample_batch) + if sum(b.count + for b in self.batch_buffer) >= self.train_batch_size: + train_batch = self.batch_buffer[0].concat_samples( + self.batch_buffer) + self.learner.inqueue.put(train_batch) + self.batch_buffer = [] + + # If the batch was replayed, skip the update below. 
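The replay mixing done by _augment_with_replay() further below interprets replay_proportion as the expected number of replayed batches emitted per fresh sample batch (0.5 gives roughly one replay every other batch, 2.0 gives two per batch). A sketch of just that sampling rule, with a made-up buffer of already-collected batches:

    import random

    replay_proportion = 1.5
    replay_batches = ["old_batch_a", "old_batch_b", "old_batch_c"]  # made up

    def replays_for_one_fresh_batch():
        emitted = []
        f = replay_proportion
        while random.random() < f:
            f -= 1
            emitted.append(random.choice(replay_batches))
        return emitted

    # Averages out to ~1.5 replayed batches per fresh batch over many calls.
    print(replays_for_one_fresh_batch())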
+ if ev is None: + continue + + sample_timesteps += sample_batch.count + + # Put in replay buffer if enabled + if self.replay_buffer_num_slots > 0: + self.replay_batches.append(sample_batch) + if len(self.replay_batches) > self.replay_buffer_num_slots: + self.replay_batches.pop(0) + + # Note that it's important to pull new weights once + # updated to avoid excessive correlation between actors + if weights is None or (self.learner.weights_updated + and num_sent >= self.broadcast_interval): + self.learner.weights_updated = False + weights = ray.put(self.local_evaluator.get_weights()) + num_sent = 0 + ev.set_weights.remote(weights) + self.num_weight_syncs += 1 + num_sent += 1 + + # Kick off another sample request + self.sample_tasks.add(ev, ev.sample.remote()) + + while not self.learner.outqueue.empty(): + count = self.learner.outqueue.get() + train_timesteps += count + + return sample_timesteps, train_timesteps + + def _augment_with_replay(self, sample_futures): + def can_replay(): + num_needed = int( + np.ceil(self.train_batch_size / self.sample_batch_size)) + return len(self.replay_batches) > num_needed + + for ev, sample_batch in sample_futures: + sample_batch = ray.get(sample_batch) + yield ev, sample_batch + + if can_replay(): + f = self.replay_proportion + while random.random() < f: + f -= 1 + replay_batch = random.choice(self.replay_batches) + self.num_replayed += replay_batch.count + yield None, replay_batch + + +class LearnerThread(threading.Thread): + """Background thread that updates the local model from sample trajectories. + + The learner thread communicates with the main thread through Queues. This + is needed since Ray operations can only be run on the main thread. In + addition, moving heavyweight gradient ops session runs off the main thread + improves overall throughput. + """ + + def __init__(self, local_evaluator): + threading.Thread.__init__(self) + self.learner_queue_size = WindowStat("size", 50) + self.local_evaluator = local_evaluator + self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE) + self.outqueue = queue.Queue() + self.queue_timer = TimerStat() + self.grad_timer = TimerStat() + self.load_timer = TimerStat() + self.load_wait_timer = TimerStat() + self.daemon = True + self.weights_updated = False + self.stats = {} + self.stopped = False + + def run(self): + while not self.stopped: + self.step() + + def step(self): + with self.queue_timer: + batch = self.inqueue.get() + + with self.grad_timer: + fetches = self.local_evaluator.compute_apply(batch) + self.weights_updated = True + self.stats = fetches.get("stats", {}) + + self.outqueue.put(batch.count) + self.learner_queue_size.push(self.inqueue.qsize()) + + +class TFMultiGPULearner(LearnerThread): + """Learner that can use multiple GPUs and parallel loading.""" + + def __init__(self, + local_evaluator, + num_gpus=1, + lr=0.0005, + train_batch_size=500, + grad_clip=40, + num_parallel_data_loaders=1): + # Multi-GPU requires TensorFlow to function. 
+ import tensorflow as tf + + LearnerThread.__init__(self, local_evaluator) + self.lr = lr + self.train_batch_size = train_batch_size + if not num_gpus: + self.devices = ["/cpu:0"] + else: + self.devices = ["/gpu:{}".format(i) for i in range(num_gpus)] + logger.info("TFMultiGPULearner devices {}".format(self.devices)) + assert self.train_batch_size % len(self.devices) == 0 + assert self.train_batch_size >= len(self.devices), "batch too small" + self.policy = self.local_evaluator.policy_map["default"] + + # per-GPU graph copies created below must share vars with the policy + # reuse is set to AUTO_REUSE because Adam nodes are created after + # all of the device copies are created. + self.par_opt = [] + with self.local_evaluator.tf_sess.graph.as_default(): + with self.local_evaluator.tf_sess.as_default(): + with tf.variable_scope("default", reuse=tf.AUTO_REUSE): + if self.policy._state_inputs: + rnn_inputs = self.policy._state_inputs + [ + self.policy._seq_lens + ] + else: + rnn_inputs = [] + adam = tf.train.AdamOptimizer(self.lr) + for _ in range(num_parallel_data_loaders): + self.par_opt.append( + LocalSyncParallelOptimizer( + adam, + self.devices, + [v for _, v in self.policy._loss_inputs], + rnn_inputs, + 999999, # it will get rounded down + self.policy.copy, + grad_norm_clipping=grad_clip)) + + self.sess = self.local_evaluator.tf_sess + self.sess.run(tf.global_variables_initializer()) + + self.idle_optimizers = queue.Queue() + self.ready_optimizers = queue.Queue() + for opt in self.par_opt: + self.idle_optimizers.put(opt) + for i in range(NUM_DATA_LOAD_THREADS): + self.loader_thread = _LoaderThread(self, share_stats=(i == 0)) + self.loader_thread.start() + + @override(LearnerThread) + def step(self): + assert self.loader_thread.is_alive() + with self.load_wait_timer: + opt = self.ready_optimizers.get() + + with self.grad_timer: + fetches = opt.optimize(self.sess, 0) + self.weights_updated = True + self.stats = fetches.get("stats", {}) + + self.idle_optimizers.put(opt) + self.outqueue.put(self.train_batch_size) + self.learner_queue_size.push(self.inqueue.qsize()) + + +class _LoaderThread(threading.Thread): + def __init__(self, learner, share_stats): + threading.Thread.__init__(self) + self.learner = learner + self.daemon = True + if share_stats: + self.queue_timer = learner.queue_timer + self.load_timer = learner.load_timer + else: + self.queue_timer = TimerStat() + self.load_timer = TimerStat() + + def run(self): + while True: + self._step() + + def _step(self): + s = self.learner + with self.queue_timer: + batch = s.inqueue.get() + + opt = s.idle_optimizers.get() + + with self.load_timer: + tuples = s.policy._get_loss_inputs_dict(batch) + data_keys = [ph for _, ph in s.policy._loss_inputs] + if s.policy._state_inputs: + state_keys = s.policy._state_inputs + [s.policy._seq_lens] + else: + state_keys = [] + opt.load_data(s.sess, [tuples[k] for k in data_keys], + [tuples[k] for k in state_keys]) + + s.ready_optimizers.put(opt) diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index 7233e37e93802..c548b20cc022d 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -3,12 +3,15 @@ from __future__ import print_function from collections import namedtuple +import logging import tensorflow as tf # Variable scope in which created variables will be placed under TOWER_SCOPE_NAME = "tower" +logger = logging.getLogger(__name__) + class LocalSyncParallelOptimizer(object): """Optimizer 
that runs in parallel across multiple local devices. @@ -36,13 +39,13 @@ class LocalSyncParallelOptimizer(object): to define the per-device loss ops. rnn_inputs: Extra input placeholders for RNN inputs. These will have shape [BATCH_SIZE // MAX_SEQ_LEN, ...]. - per_device_batch_size: Number of tuples to optimize over at a time per - device. In each call to `optimize()`, + max_per_device_batch_size: Number of tuples to optimize over at a time + per device. In each call to `optimize()`, `len(devices) * per_device_batch_size` tuples of data will be - processed. + processed. If this is larger than the total data size, it will be + clipped. build_graph: Function that takes the specified inputs and returns a TF Policy Graph instance. - logdir: Directory to place debugging output in. grad_norm_clipping: None or int stdev to clip grad norms by """ @@ -51,26 +54,29 @@ def __init__(self, devices, input_placeholders, rnn_inputs, - per_device_batch_size, + max_per_device_batch_size, build_graph, - logdir, grad_norm_clipping=None): - # TODO(rliaw): remove logdir self.optimizer = optimizer self.devices = devices - self.batch_size = per_device_batch_size * len(devices) - self.per_device_batch_size = per_device_batch_size + self.max_per_device_batch_size = max_per_device_batch_size self.loss_inputs = input_placeholders + rnn_inputs self.build_graph = build_graph - self.logdir = logdir # First initialize the shared loss network with tf.name_scope(TOWER_SCOPE_NAME): self._shared_loss = build_graph(self.loss_inputs) + shared_ops = tf.get_collection( + tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) # Then setup the per-device loss graphs that use the shared weights self._batch_index = tf.placeholder(tf.int32, name="batch_index") + # Dynamic batch size, which may be shrunk if there isn't enough data + self._per_device_batch_size = tf.placeholder( + tf.int32, name="per_device_batch_size") + self._loaded_per_device_batch_size = max_per_device_batch_size + # When loading RNN input, we dynamically determine the max seq len self._max_seq_len = tf.placeholder(tf.int32, name="max_seq_len") self._loaded_max_seq_len = 1 @@ -88,10 +94,26 @@ def __init__(self, avg = average_gradients([t.grads for t in self._towers]) if grad_norm_clipping: + clipped = [] + for grad, _ in avg: + clipped.append(grad) + clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping) for i, (grad, var) in enumerate(avg): - if grad is not None: - avg[i] = (tf.clip_by_norm(grad, grad_norm_clipping), var) - self._train_op = self.optimizer.apply_gradients(avg) + avg[i] = (clipped[i], var) + + # gather update ops for any batch norm layers. TODO(ekl) here we will + # use all the ops found which won't work for DQN / DDPG, but those + # aren't supported with multi-gpu right now anyways. + self._update_ops = tf.get_collection( + tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) + for op in shared_ops: + self._update_ops.remove(op) # only care about tower update ops + if self._update_ops: + logger.debug("Update ops to run on apply gradient: {}".format( + self._update_ops)) + + with tf.control_dependencies(self._update_ops): + self._train_op = self.optimizer.apply_gradients(avg) def load_data(self, sess, inputs, state_inputs): """Bulk loads the specified inputs into device memory. 
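load_data() below first works out how many whole sequences fit per device and then truncates the arrays so they divide evenly across devices, using the updated make_divisible_by() helper (which now accepts plain ints as well as arrays). A quick sketch of that truncation, assuming only numpy; the helper body mirrors the one in multi_gpu_impl.py:

    import numpy as np

    def make_divisible_by(a, n):
        if type(a) is int:
            return a - a % n
        return a[0:a.shape[0] - a.shape[0] % n]

    num_devices = 4
    rows = np.arange(10)
    print(make_divisible_by(rows, num_devices))   # keeps only the first 8 rows
    print(make_divisible_by(10, num_devices))     # -> 8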
@@ -117,44 +139,64 @@ def load_data(self, sess, inputs, state_inputs): assert len(self.loss_inputs) == len(inputs + state_inputs), \ (self.loss_inputs, inputs, state_inputs) - # The RNN truncation case is more complicated + # Let's suppose we have the following input data, and 2 devices: + # 1 2 3 4 5 6 7 <- state inputs shape + # A A A B B B C C C D D D E E E F F F G G G <- inputs shape + # The data is truncated and split across devices as follows: + # |---| seq len = 3 + # |---------------------------------| seq batch size = 6 seqs + # |----------------| per device batch size = 9 tuples + if len(state_inputs) > 0: + smallest_array = state_inputs[0] seq_len = len(inputs[0]) // len(state_inputs[0]) self._loaded_max_seq_len = seq_len - assert len(state_inputs[0]) * seq_len == len(inputs[0]) - # Make sure the shorter state inputs arrays are evenly divisible + else: + smallest_array = inputs[0] + self._loaded_max_seq_len = 1 + + seq_batch_size = (self.max_per_device_batch_size // + self._loaded_max_seq_len * len(self.devices)) + if len(smallest_array) < seq_batch_size: + # Dynamically shrink the batch size if insufficient data + seq_batch_size = make_divisible_by( + len(smallest_array), len(self.devices)) + if seq_batch_size < len(self.devices): + raise ValueError("Must load at least 1 tuple sequence per device, " + "got only {} total.".format(len(smallest_array))) + self._loaded_per_device_batch_size = ( + seq_batch_size // len(self.devices) * self._loaded_max_seq_len) + + if len(state_inputs) > 0: + # First truncate the RNN state arrays to the seq_batch_size state_inputs = [ - make_divisible_by(arr, self.batch_size) for arr in state_inputs + make_divisible_by(arr, seq_batch_size) for arr in state_inputs ] # Then truncate the data inputs to match inputs = [arr[:len(state_inputs[0]) * seq_len] for arr in inputs] - assert len(state_inputs[0]) * seq_len == len(inputs[0]) - assert len(state_inputs[0]) % self.batch_size == 0 + assert len(state_inputs[0]) * seq_len == len(inputs[0]), \ + (len(state_inputs[0]), seq_batch_size, seq_len, len(inputs[0])) for ph, arr in zip(self.loss_inputs, inputs + state_inputs): feed_dict[ph] = arr truncated_len = len(inputs[0]) else: for ph, arr in zip(self.loss_inputs, inputs + state_inputs): - truncated_arr = make_divisible_by(arr, self.batch_size) + truncated_arr = make_divisible_by(arr, seq_batch_size) feed_dict[ph] = truncated_arr truncated_len = len(truncated_arr) sess.run([t.init_op for t in self._towers], feed_dict=feed_dict) tuples_per_device = truncated_len / len(self.devices) - assert tuples_per_device > 0, \ - "Too few tuples per batch, trying increasing the training " \ - "batch size or decreasing the sgd batch size. Tried to split up " \ - "{} rows {}-ways in batches of {} (total across devices).".format( - len(arr), len(self.devices), self.batch_size) - assert tuples_per_device % self.per_device_batch_size == 0 + assert tuples_per_device > 0, "No data loaded?" + assert tuples_per_device % self._loaded_per_device_batch_size == 0 return tuples_per_device def optimize(self, sess, batch_index): """Run a single step of SGD. Runs a SGD step over a slice of the preloaded batch with size given by - self.per_device_batch_size and offset given by the batch_index + self._loaded_per_device_batch_size and offset given by the batch_index argument. Updates shared model weights based on the averaged per-device @@ -164,13 +206,14 @@ def optimize(self, sess, batch_index): sess: TensorFlow session. batch_index: Offset into the preloaded data. 
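Callers drive optimize() by stepping batch_index through the preloaded data in strides of the per-device batch size, typically in a shuffled order; this is what LocalMultiGPUOptimizer.step() does further below. A sketch of that outer loop with made-up sizes, assuming numpy:

    import numpy as np

    tuples_per_device = 512          # as returned by load_data()
    per_device_batch_size = 128
    num_sgd_iter = 3

    num_batches = int(tuples_per_device) // int(per_device_batch_size)
    for sgd_iter in range(num_sgd_iter):
        permutation = np.random.permutation(num_batches)
        for i in range(num_batches):
            batch_index = permutation[i] * per_device_batch_size
            # par_opt.optimize(sess, batch_index) would run one SGD step here.
            print(sgd_iter, batch_index)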
This value must be between `0` and `tuples_per_device`. The amount of data to - process is always fixed to `per_device_batch_size`. + process is at most `max_per_device_batch_size`. Returns: The outputs of extra_ops evaluated over the batch. """ feed_dict = { self._batch_index: batch_index, + self._per_device_batch_size: self._loaded_per_device_batch_size, self._max_seq_len: self._loaded_max_seq_len, } for tower in self._towers: @@ -213,7 +256,7 @@ def _setup_device(self, device, device_input_placeholders, num_data_in): current_batch, ([self._batch_index // scale * granularity] + [0] * len(ph.shape[1:])), - ([self.per_device_batch_size // scale * granularity] + + ([self._per_device_batch_size // scale * granularity] + [-1] * len(ph.shape[1:]))) current_slice.set_shape(ph.shape) device_input_slices.append(current_slice) @@ -229,8 +272,10 @@ def _setup_device(self, device, device_input_placeholders, num_data_in): Tower = namedtuple("Tower", ["init_op", "grads", "loss_graph"]) -def make_divisible_by(array, n): - return array[0:array.shape[0] - array.shape[0] % n] +def make_divisible_by(a, n): + if type(a) is int: + return a - a % n + return a[0:a.shape[0] - a.shape[0] % n] def average_gradients(tower_grads): diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index e474570363937..5ca29f68c8861 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -2,16 +2,22 @@ from __future__ import division from __future__ import print_function +import logging +import math import numpy as np from collections import defaultdict -import os import tensorflow as tf import ray from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer +from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat +from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ + MultiAgentBatch + +logger = logging.getLogger(__name__) class LocalMultiGPUOptimizer(PolicyOptimizer): @@ -30,6 +36,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): may result in unexpected behavior. """ + @override(PolicyOptimizer) def _init(self, sgd_batch_size=128, num_sgd_iter=10, @@ -42,7 +49,9 @@ def _init(self, if not num_gpus: self.devices = ["/cpu:0"] else: - self.devices = ["/gpu:{}".format(i) for i in range(num_gpus)] + self.devices = [ + "/gpu:{}".format(i) for i in range(int(math.ceil(num_gpus))) + ] self.batch_size = int(sgd_batch_size / len(self.devices)) * len( self.devices) assert self.batch_size % len(self.devices) == 0 @@ -54,39 +63,40 @@ def _init(self, self.update_weights_timer = TimerStat() self.standardize_fields = standardize_fields - print("LocalMultiGPUOptimizer devices", self.devices) + logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices)) - if set(self.local_evaluator.policy_map.keys()) != {"default"}: - raise ValueError( - "Multi-agent is not supported with multi-GPU. Try using the " - "simple optimizer instead.") - self.policy = self.local_evaluator.policy_map["default"] - if not isinstance(self.policy, TFPolicyGraph): - raise ValueError( - "Only TF policies are supported with multi-GPU. 
Try using the " - "simple optimizer instead.") + self.policies = self.local_evaluator.policy_map + for policy_id, policy in self.policies.items(): + if not isinstance(policy, TFPolicyGraph): + raise ValueError( + "Only TF policies are supported with multi-GPU. Try using " + "the simple optimizer instead.") # per-GPU graph copies created below must share vars with the policy # reuse is set to AUTO_REUSE because Adam nodes are created after # all of the device copies are created. + self.optimizers = {} with self.local_evaluator.tf_sess.graph.as_default(): with self.local_evaluator.tf_sess.as_default(): - with tf.variable_scope("default", reuse=tf.AUTO_REUSE): - if self.policy._state_inputs: - rnn_inputs = self.policy._state_inputs + [ - self.policy._seq_lens - ] - else: - rnn_inputs = [] - self.par_opt = LocalSyncParallelOptimizer( - self.policy.optimizer(), self.devices, - [v for _, v in self.policy.loss_inputs()], rnn_inputs, - self.per_device_batch_size, self.policy.copy, - os.getcwd()) + for policy_id, policy in self.policies.items(): + with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE): + if policy._state_inputs: + rnn_inputs = policy._state_inputs + [ + policy._seq_lens + ] + else: + rnn_inputs = [] + self.optimizers[policy_id] = ( + LocalSyncParallelOptimizer( + policy._optimizer, self.devices, + [v + for _, v in policy._loss_inputs], rnn_inputs, + self.per_device_batch_size, policy.copy)) self.sess = self.local_evaluator.tf_sess self.sess.run(tf.global_variables_initializer()) + @override(PolicyOptimizer) def step(self): with self.update_weights_timer: if self.remote_evaluators: @@ -102,48 +112,64 @@ def step(self): self.train_batch_size) else: samples = self.local_evaluator.sample() - self._check_not_multiagent(samples) - - for field in self.standardize_fields: - value = samples[field] - standardized = (value - value.mean()) / max(1e-4, value.std()) - samples[field] = standardized - - # Important: don't shuffle RNN sequence elements - if not self.policy._state_inputs: - samples.shuffle() - + # Handle everything as if multiagent + if isinstance(samples, SampleBatch): + samples = MultiAgentBatch({ + DEFAULT_POLICY_ID: samples + }, samples.count) + + for _, batch in samples.policy_batches.items(): + for field in self.standardize_fields: + value = batch[field] + standardized = (value - value.mean()) / max(1e-4, value.std()) + batch[field] = standardized + + for policy_id, policy in self.policies.items(): + # Important: don't shuffle RNN sequence elements + if (policy_id in samples.policy_batches + and not policy._state_inputs): + samples.policy_batches[policy_id].shuffle() + + num_loaded_tuples = {} with self.load_timer: - tuples = self.policy._get_loss_inputs_dict(samples) - data_keys = [ph for _, ph in self.policy.loss_inputs()] - if self.policy._state_inputs: - state_keys = ( - self.policy._state_inputs + [self.policy._seq_lens]) - else: - state_keys = [] - tuples_per_device = self.par_opt.load_data( - self.sess, [tuples[k] for k in data_keys], - [tuples[k] for k in state_keys]) - + for policy_id, batch in samples.policy_batches.items(): + policy = self.policies[policy_id] + tuples = policy._get_loss_inputs_dict(batch) + data_keys = [ph for _, ph in policy._loss_inputs] + if policy._state_inputs: + state_keys = policy._state_inputs + [policy._seq_lens] + else: + state_keys = [] + num_loaded_tuples[policy_id] = ( + self.optimizers[policy_id].load_data( + self.sess, [tuples[k] for k in data_keys], + [tuples[k] for k in state_keys])) + + fetches = {} with self.grad_timer: - 
num_batches = ( - int(tuples_per_device) // int(self.per_device_batch_size)) - print("== sgd epochs ==") - for i in range(self.num_sgd_iter): - iter_extra_fetches = defaultdict(list) - permutation = np.random.permutation(num_batches) - for batch_index in range(num_batches): - batch_fetches = self.par_opt.optimize( - self.sess, - permutation[batch_index] * self.per_device_batch_size) - for k, v in batch_fetches.items(): - iter_extra_fetches[k].append(v) - print(i, _averaged(iter_extra_fetches)) + for policy_id, tuples_per_device in num_loaded_tuples.items(): + optimizer = self.optimizers[policy_id] + num_batches = ( + int(tuples_per_device) // int(self.per_device_batch_size)) + logger.debug("== sgd epochs for {} ==".format(policy_id)) + for i in range(self.num_sgd_iter): + iter_extra_fetches = defaultdict(list) + permutation = np.random.permutation(num_batches) + for batch_index in range(num_batches): + batch_fetches = optimizer.optimize( + self.sess, permutation[batch_index] * + self.per_device_batch_size) + for k, v in batch_fetches.items(): + iter_extra_fetches[k].append(v) + logger.debug("{} {}".format(i, + _averaged(iter_extra_fetches))) + fetches[policy_id] = _averaged(iter_extra_fetches) self.num_steps_sampled += samples.count self.num_steps_trained += samples.count - return _averaged(iter_extra_fetches) + return fetches + @override(PolicyOptimizer) def stats(self): return dict( PolicyOptimizer.stats(self), **{ diff --git a/python/ray/rllib/optimizers/policy_optimizer.py b/python/ray/rllib/optimizers/policy_optimizer.py index 21fcf5f0b7a77..a0cc085eec898 100644 --- a/python/ray/rllib/optimizers/policy_optimizer.py +++ b/python/ray/rllib/optimizers/policy_optimizer.py @@ -2,10 +2,13 @@ from __future__ import division from __future__ import print_function +import logging + import ray from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes -from ray.rllib.evaluation.sample_batch import MultiAgentBatch + +logger = logging.getLogger(__name__) class PolicyOptimizer(object): @@ -53,10 +56,13 @@ def __init__(self, local_evaluator, remote_evaluators=None, config=None): self.num_steps_trained = 0 self.num_steps_sampled = 0 + logger.debug("Created policy optimizer with {}: {}".format( + config, self)) + def _init(self): """Subclasses should prefer overriding this instead of __init__.""" - pass + raise NotImplementedError def step(self): """Takes a logical optimization step. @@ -79,18 +85,42 @@ def stats(self): "num_steps_sampled": self.num_steps_sampled, } - def collect_metrics(self, min_history=100): + def save(self): + """Returns a serializable object representing the optimizer state.""" + + return [self.num_steps_trained, self.num_steps_sampled] + + def restore(self, data): + """Restores optimizer state from the given data object.""" + + self.num_steps_trained = data[0] + self.num_steps_sampled = data[1] + + def stop(self): + """Release any resources used by this optimizer.""" + pass + + def collect_metrics(self, + timeout_seconds, + min_history=100, + selected_evaluators=None): """Returns evaluator and optimizer stats. Arguments: + timeout_seconds (int): Max wait time for a evaluator before + dropping its results. This usually indicates a hung evaluator. min_history (int): Min history length to smooth results over. + selected_evaluators (list): Override the list of remote evaluators + to collect metrics from. 
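The smoothing that follows pads a short list of freshly collected episodes with the tail of self.episode_history until min_history is reached; the padding call itself sits in context elided by this hunk, so the extend line below is an assumption based on the surrounding bookkeeping. A self-contained sketch with toy episode markers:

    min_history = 5
    episode_history = ["e1", "e2", "e3", "e4"]   # retained from earlier calls
    episodes = ["e5", "e6"]                      # fresh this call (made up)

    orig_episodes = list(episodes)
    missing = min_history - len(episodes)
    if missing > 0:
        episodes.extend(episode_history[-missing:])   # assumed padding step
        assert len(episodes) <= min_history
    episode_history.extend(orig_episodes)
    episode_history = episode_history[-min_history:]

    print(episodes)          # ['e5', 'e6', 'e2', 'e3', 'e4']
    print(episode_history)   # ['e2', 'e3', 'e4', 'e5', 'e6']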
Returns: res (dict): A training result dict from evaluator metrics with `info` replaced with stats from self. """ - episodes = collect_episodes(self.local_evaluator, - self.remote_evaluators) + episodes, num_dropped = collect_episodes( + self.local_evaluator, + selected_evaluators or self.remote_evaluators, + timeout_seconds=timeout_seconds) orig_episodes = list(episodes) missing = min_history - len(episodes) if missing > 0: @@ -98,21 +128,10 @@ def collect_metrics(self, min_history=100): assert len(episodes) <= min_history self.episode_history.extend(orig_episodes) self.episode_history = self.episode_history[-min_history:] - res = summarize_episodes(episodes, orig_episodes) + res = summarize_episodes(episodes, orig_episodes, num_dropped) res.update(info=self.stats()) return res - def save(self): - """Returns a serializable object representing the optimizer state.""" - - return [self.num_steps_trained, self.num_steps_sampled] - - def restore(self, data): - """Restores optimizer state from the given data object.""" - - self.num_steps_trained = data[0] - self.num_steps_sampled = data[1] - def foreach_evaluator(self, func): """Apply the given function to each evaluator instance.""" @@ -134,12 +153,6 @@ def foreach_evaluator_with_index(self, func): ]) return local_result + remote_results - @staticmethod - def _check_not_multiagent(sample_batch): - if isinstance(sample_batch, MultiAgentBatch): - raise NotImplementedError( - "This optimizer does not support multi-agent yet.") - @classmethod def make(cls, env_creator, diff --git a/python/ray/rllib/optimizers/replay_buffer.py b/python/ray/rllib/optimizers/replay_buffer.py index 77d954345668f..cd5ec732848e0 100644 --- a/python/ray/rllib/optimizers/replay_buffer.py +++ b/python/ray/rllib/optimizers/replay_buffer.py @@ -93,14 +93,15 @@ def sample(self, batch_size): self._num_sampled += batch_size return self._encode_sample(idxes) - def stats(self): + def stats(self, debug=False): data = { "added_count": self._num_added, "sampled_count": self._num_sampled, "est_size_bytes": self._est_size_bytes, "num_entries": len(self._storage), } - data.update(self._evicted_hit_stats.stats()) + if debug: + data.update(self._evicted_hit_stats.stats()) return data @@ -233,7 +234,8 @@ def update_priorities(self, idxes, priorities): self._max_priority = max(self._max_priority, priority) - def stats(self): - parent = ReplayBuffer.stats(self) - parent.update(self._prio_change_stats.stats()) + def stats(self, debug=False): + parent = ReplayBuffer.stats(self, debug) + if debug: + parent.update(self._prio_change_stats.stats()) return parent diff --git a/python/ray/rllib/optimizers/sync_replay_optimizer.py b/python/ray/rllib/optimizers/sync_replay_optimizer.py index 73df006014679..f2a42a08302a5 100644 --- a/python/ray/rllib/optimizers/sync_replay_optimizer.py +++ b/python/ray/rllib/optimizers/sync_replay_optimizer.py @@ -11,6 +11,7 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch +from ray.rllib.utils.annotations import override from ray.rllib.utils.compression import pack_if_needed from ray.rllib.utils.filter import RunningStat from ray.rllib.utils.timer import TimerStat @@ -24,6 +25,7 @@ class SyncReplayOptimizer(PolicyOptimizer): "td_error" array in the info return of compute_gradients(). 
This error term will be used for sample prioritization.""" + @override(PolicyOptimizer) def _init(self, learning_starts=1000, buffer_size=10000, @@ -53,6 +55,7 @@ def _init(self, self.replay_timer = TimerStat() self.grad_timer = TimerStat() self.throughput = RunningStat() + self.learner_stats = {} # Set up replay buffer if prioritized_replay: @@ -69,6 +72,7 @@ def new_buffer(): assert buffer_size >= self.replay_starts + @override(PolicyOptimizer) def step(self): with self.update_weights_timer: if self.remote_evaluators: @@ -105,12 +109,29 @@ def step(self): self.num_steps_sampled += batch.count + @override(PolicyOptimizer) + def stats(self): + return dict( + PolicyOptimizer.stats(self), **{ + "sample_time_ms": round(1000 * self.sample_timer.mean, 3), + "replay_time_ms": round(1000 * self.replay_timer.mean, 3), + "grad_time_ms": round(1000 * self.grad_timer.mean, 3), + "update_time_ms": round(1000 * self.update_weights_timer.mean, + 3), + "opt_peak_throughput": round(self.grad_timer.mean_throughput, + 3), + "opt_samples": round(self.grad_timer.mean_units_processed, 3), + "learner": self.learner_stats, + }) + def _optimize(self): samples = self._replay() with self.grad_timer: info_dict = self.local_evaluator.compute_apply(samples) for policy_id, info in info_dict.items(): + if "stats" in info: + self.learner_stats[policy_id] = info["stats"] replay_buffer = self.replay_buffers[policy_id] if isinstance(replay_buffer, PrioritizedReplayBuffer): td_error = info["td_error"] @@ -148,16 +169,3 @@ def _replay(self): "batch_indexes": batch_indexes }) return MultiAgentBatch(samples, self.train_batch_size) - - def stats(self): - return dict( - PolicyOptimizer.stats(self), **{ - "sample_time_ms": round(1000 * self.sample_timer.mean, 3), - "replay_time_ms": round(1000 * self.replay_timer.mean, 3), - "grad_time_ms": round(1000 * self.grad_timer.mean, 3), - "update_time_ms": round(1000 * self.update_weights_timer.mean, - 3), - "opt_peak_throughput": round(self.grad_timer.mean_throughput, - 3), - "opt_samples": round(self.grad_timer.mean_units_processed, 3), - }) diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py index 20922ff54036a..b78e3ed01d70e 100644 --- a/python/ray/rllib/optimizers/sync_samples_optimizer.py +++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py @@ -3,11 +3,15 @@ from __future__ import print_function import ray +import logging from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override from ray.rllib.utils.filter import RunningStat from ray.rllib.utils.timer import TimerStat +logger = logging.getLogger(__name__) + class SyncSamplesOptimizer(PolicyOptimizer): """A simple synchronous RL optimizer. @@ -17,6 +21,7 @@ class SyncSamplesOptimizer(PolicyOptimizer): model weights are then broadcast to all remote evaluators. 
""" + @override(PolicyOptimizer) def _init(self, num_sgd_iter=1, train_batch_size=1): self.update_weights_timer = TimerStat() self.sample_timer = TimerStat() @@ -26,6 +31,7 @@ def _init(self, num_sgd_iter=1, train_batch_size=1): self.train_batch_size = train_batch_size self.learner_stats = {} + @override(PolicyOptimizer) def step(self): with self.update_weights_timer: if self.remote_evaluators: @@ -52,13 +58,14 @@ def step(self): if "stats" in fetches: self.learner_stats = fetches["stats"] if self.num_sgd_iter > 1: - print(i, fetches) + logger.debug("{} {}".format(i, fetches)) self.grad_timer.push_units_processed(samples.count) self.num_steps_sampled += samples.count self.num_steps_trained += samples.count return fetches + @override(PolicyOptimizer) def stats(self): return dict( PolicyOptimizer.stats(self), **{ diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index 0e33e3d6ced61..bee5c5eb2ae12 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -12,8 +12,6 @@ import gym import ray from ray.rllib.agents.agent import get_agent_class -from ray.rllib.agents.dqn.common.wrappers import wrap_dqn -from ray.rllib.models import ModelCatalog EXAMPLE_USAGE = """ Example Usage via RLlib CLI: @@ -54,7 +52,7 @@ def create_parser(parser_creator=None): const=True, help="Surpress rendering of the environment.") parser.add_argument( - "--steps", default=None, help="Number of steps to roll out.") + "--steps", default=10000, help="Number of steps to roll out.") parser.add_argument("--out", default=None, help="Output filename.") parser.add_argument( "--config", @@ -66,30 +64,38 @@ def create_parser(parser_creator=None): def run(args, parser): - if not args.config: + config = args.config + if not config: # Load configuration from file config_dir = os.path.dirname(args.checkpoint) config_path = os.path.join(config_dir, "params.json") + if not os.path.exists(config_path): + config_path = os.path.join(config_dir, "../params.json") + if not os.path.exists(config_path): + raise ValueError( + "Could not find params.json in either the checkpoint dir or " + "its parent directory.") with open(config_path) as f: - args.config = json.load(f) + config = json.load(f) + if "num_workers" in config: + config["num_workers"] = min(2, config["num_workers"]) if not args.env: - if not args.config.get("env"): + if not config.get("env"): parser.error("the following arguments are required: --env") - args.env = args.config.get("env") + args.env = config.get("env") ray.init() cls = get_agent_class(args.run) - agent = cls(env=args.env, config=args.config) + agent = cls(env=args.env, config=config) agent.restore(args.checkpoint) num_steps = int(args.steps) - if args.run == "DQN": - env = gym.make(args.env) - env = wrap_dqn(env, args.config.get("model", {})) + if hasattr(agent, "local_evaluator"): + env = agent.local_evaluator.env else: - env = ModelCatalog.get_preprocessor_as_wrapper(gym.make(args.env)) + env = gym.make(args.env) if args.out is not None: rollouts = [] steps = 0 diff --git a/python/ray/rllib/scripts.py b/python/ray/rllib/scripts.py index cc48b83cf3341..88d5d56292b13 100644 --- a/python/ray/rllib/scripts.py +++ b/python/ray/rllib/scripts.py @@ -14,7 +14,7 @@ rllib train --run DQN --env CartPole-v0 Example usage for rollout: - rllib rollout /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN + rllib rollout /trial_dir/checkpoint_1/checkpoint-1 --run DQN """ diff --git a/python/ray/rllib/test/test_catalog.py b/python/ray/rllib/test/test_catalog.py index e3dc1e782535d..efa1aba0e2f07 
100644 --- a/python/ray/rllib/test/test_catalog.py +++ b/python/ray/rllib/test/test_catalog.py @@ -15,16 +15,18 @@ class CustomPreprocessor(Preprocessor): - pass + def _init_shape(self, obs_space, options): + return None class CustomPreprocessor2(Preprocessor): - pass + def _init_shape(self, obs_space, options): + return None class CustomModel(Model): def _build_layers(self, *args): - return None, None + return tf.constant([[0] * 5]), None class ModelCatalogTest(unittest.TestCase): @@ -69,19 +71,24 @@ def testDefaultModels(self): ray.init() with tf.variable_scope("test1"): - p1 = ModelCatalog.get_model(np.zeros((10, 3), dtype=np.float32), 5) + p1 = ModelCatalog.get_model({ + "obs": tf.zeros((10, 3), dtype=tf.float32) + }, Box(0, 1, shape=(3, ), dtype=np.float32), 5, {}) self.assertEqual(type(p1), FullyConnectedNetwork) with tf.variable_scope("test2"): - p2 = ModelCatalog.get_model( - np.zeros((10, 84, 84, 3), dtype=np.float32), 5) + p2 = ModelCatalog.get_model({ + "obs": tf.zeros((10, 84, 84, 3), dtype=tf.float32) + }, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), 5, {}) self.assertEqual(type(p2), VisionNetwork) def testCustomModel(self): ray.init() ModelCatalog.register_custom_model("foo", CustomModel) - p1 = ModelCatalog.get_model( - tf.constant([1, 2, 3]), 5, {"custom_model": "foo"}) + p1 = ModelCatalog.get_model({ + "obs": tf.constant([1, 2, 3]) + }, Box(0, 1, shape=(3, ), dtype=np.float32), 5, + {"custom_model": "foo"}) self.assertEqual(str(type(p1)), str(CustomModel)) diff --git a/python/ray/rllib/test/test_checkpoint_restore.py b/python/ray/rllib/test/test_checkpoint_restore.py index cb371c90c29f8..aa8fac28086ab 100644 --- a/python/ray/rllib/test/test_checkpoint_restore.py +++ b/python/ray/rllib/test/test_checkpoint_restore.py @@ -23,7 +23,8 @@ def get_mean_action(alg, obs): "ES": { "episodes_per_batch": 10, "train_batch_size": 100, - "num_workers": 2 + "num_workers": 2, + "observation_filter": "MeanStdFilter" }, "DQN": {}, "APEX_DDPG": { @@ -46,6 +47,11 @@ def get_mean_action(alg, obs): "A3C": { "num_workers": 1 }, + "ARS": { + "num_rollouts": 10, + "num_workers": 2, + "observation_filter": "MeanStdFilter" + } } @@ -83,7 +89,7 @@ def test(use_object_store, alg_name, failures): if __name__ == "__main__": failures = [] for use_object_store in [False, True]: - for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG"]: + for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"]: test(use_object_store, name, failures) assert not failures, failures diff --git a/python/ray/rllib/test/test_env_with_subprocess.py b/python/ray/rllib/test/test_env_with_subprocess.py new file mode 100644 index 0000000000000..fc940cdea05eb --- /dev/null +++ b/python/ray/rllib/test/test_env_with_subprocess.py @@ -0,0 +1,78 @@ +"""Tests that envs clean up after themselves on agent exit.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from gym.spaces import Discrete +import atexit +import gym +import os +import subprocess +import tempfile +import time + +import ray +from ray.tune import run_experiments +from ray.tune.registry import register_env + +# Dummy command to run as a subprocess with a unique name +UNIQUE_CMD = "sleep {}".format(str(time.time())) +_, UNIQUE_FILE_0 = tempfile.mkstemp("test_env_with_subprocess") +_, UNIQUE_FILE_1 = tempfile.mkstemp("test_env_with_subprocess") + + +class EnvWithSubprocess(gym.Env): + """Our env that spawns a subprocess.""" + + def __init__(self, config): + self.action_space = Discrete(2) + 
self.observation_space = Discrete(2) + # Subprocess that should be cleaned up + self.subproc = subprocess.Popen(UNIQUE_CMD.split(" "), shell=False) + # Exit handler should be called + if config.worker_index == 0: + atexit.register(lambda: os.unlink(UNIQUE_FILE_0)) + else: + atexit.register(lambda: os.unlink(UNIQUE_FILE_1)) + atexit.register(lambda: self.subproc.kill()) + + def reset(self): + return 0 + + def step(self, action): + return 0, 0, True, {} + + +def leaked_processes(): + """Returns whether any subprocesses were leaked.""" + result = subprocess.check_output( + "ps aux | grep '{}' | grep -v grep || true".format(UNIQUE_CMD), + shell=True) + return result + + +if __name__ == "__main__": + register_env("subproc", lambda config: EnvWithSubprocess(config)) + ray.init() + assert os.path.exists(UNIQUE_FILE_0) + assert os.path.exists(UNIQUE_FILE_1) + assert not leaked_processes() + run_experiments({ + "demo": { + "run": "PG", + "env": "subproc", + "num_samples": 1, + "config": { + "num_workers": 1, + }, + "stop": { + "training_iteration": 1 + }, + }, + }) + leaked = leaked_processes() + assert not leaked, "LEAKED PROCESSES: {}".format(leaked) + assert not os.path.exists(UNIQUE_FILE_0), "atexit handler not called" + assert not os.path.exists(UNIQUE_FILE_1), "atexit handler not called" + print("OK") diff --git a/python/ray/rllib/test/test_evaluators.py b/python/ray/rllib/test/test_evaluators.py index 9ae0994f33466..c7a72d7a5bb87 100644 --- a/python/ray/rllib/test/test_evaluators.py +++ b/python/ray/rllib/test/test_evaluators.py @@ -4,7 +4,7 @@ import unittest -from ray.rllib.agents.dqn.dqn_policy_graph import adjust_nstep +from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep class DQNTest(unittest.TestCase): @@ -14,7 +14,7 @@ def testNStep(self): rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0] new_obs = [2, 3, 4, 5, 6, 7, 8] dones = [0, 0, 0, 0, 0, 0, 1] - adjust_nstep(3, 0.9, obs, actions, rewards, new_obs, dones) + _adjust_nstep(3, 0.9, obs, actions, rewards, new_obs, dones) self.assertEqual(obs, [1, 2, 3, 4, 5, 6, 7]) self.assertEqual(actions, ["a", "b", "a", "a", "a", "b", "a"]) self.assertEqual(new_obs, [4, 5, 6, 7, 8, 8, 8]) diff --git a/python/ray/rllib/test/test_serving_env.py b/python/ray/rllib/test/test_external_env.py similarity index 88% rename from python/ray/rllib/test/test_serving_env.py rename to python/ray/rllib/test/test_external_env.py index 6f47eeeeeedd8..f7e8308a5ff17 100644 --- a/python/ray/rllib/test/test_serving_env.py +++ b/python/ray/rllib/test/test_external_env.py @@ -12,15 +12,15 @@ from ray.rllib.agents.dqn import DQNAgent from ray.rllib.agents.pg import PGAgent from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator -from ray.rllib.env.serving_env import ServingEnv +from ray.rllib.env.external_env import ExternalEnv from ray.rllib.test.test_policy_evaluator import BadPolicyGraph, \ MockPolicyGraph, MockEnv from ray.tune.registry import register_env -class SimpleServing(ServingEnv): +class SimpleServing(ExternalEnv): def __init__(self, env): - ServingEnv.__init__(self, env.action_space, env.observation_space) + ExternalEnv.__init__(self, env.action_space, env.observation_space) self.env = env def run(self): @@ -36,9 +36,9 @@ def run(self): eid = self.start_episode() -class PartOffPolicyServing(ServingEnv): +class PartOffPolicyServing(ExternalEnv): def __init__(self, env, off_pol_frac): - ServingEnv.__init__(self, env.action_space, env.observation_space) + ExternalEnv.__init__(self, env.action_space, env.observation_space) self.env = 
env self.off_pol_frac = off_pol_frac @@ -59,9 +59,9 @@ def run(self): eid = self.start_episode() -class SimpleOffPolicyServing(ServingEnv): +class SimpleOffPolicyServing(ExternalEnv): def __init__(self, env, fixed_action): - ServingEnv.__init__(self, env.action_space, env.observation_space) + ExternalEnv.__init__(self, env.action_space, env.observation_space) self.env = env self.fixed_action = fixed_action @@ -79,12 +79,12 @@ def run(self): eid = self.start_episode() -class MultiServing(ServingEnv): +class MultiServing(ExternalEnv): def __init__(self, env_creator): self.env_creator = env_creator self.env = env_creator() - ServingEnv.__init__(self, self.env.action_space, - self.env.observation_space) + ExternalEnv.__init__(self, self.env.action_space, + self.env.observation_space) def run(self): envs = [self.env_creator() for _ in range(5)] @@ -107,8 +107,8 @@ def run(self): del cur_obs[i] -class TestServingEnv(unittest.TestCase): - def testServingEnvCompleteEpisodes(self): +class TestExternalEnv(unittest.TestCase): + def testExternalEnvCompleteEpisodes(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), policy_graph=MockPolicyGraph, @@ -118,7 +118,7 @@ def testServingEnvCompleteEpisodes(self): batch = ev.sample() self.assertEqual(batch.count, 50) - def testServingEnvTruncateEpisodes(self): + def testExternalEnvTruncateEpisodes(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), policy_graph=MockPolicyGraph, @@ -128,7 +128,7 @@ def testServingEnvTruncateEpisodes(self): batch = ev.sample() self.assertEqual(batch.count, 40) - def testServingEnvOffPolicy(self): + def testExternalEnvOffPolicy(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42), policy_graph=MockPolicyGraph, @@ -140,7 +140,7 @@ def testServingEnvOffPolicy(self): self.assertEqual(batch["actions"][0], 42) self.assertEqual(batch["actions"][-1], 42) - def testServingEnvBadActions(self): + def testExternalEnvBadActions(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), policy_graph=BadPolicyGraph, @@ -185,15 +185,14 @@ def testTrainCartpoleMulti(self): return raise Exception("failed to improve reward") - def testServingEnvHorizonNotSupported(self): + def testExternalEnvHorizonNotSupported(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), policy_graph=MockPolicyGraph, episode_horizon=20, batch_steps=10, batch_mode="complete_episodes") - ev.sample() - self.assertRaises(Exception, lambda: ev.sample()) + self.assertRaises(ValueError, lambda: ev.sample()) if __name__ == '__main__': diff --git a/python/ray/rllib/test/test_lstm.py b/python/ray/rllib/test/test_lstm.py index 2abfb7680cd3d..abb9ad0ccb4a7 100644 --- a/python/ray/rllib/test/test_lstm.py +++ b/python/ray/rllib/test/test_lstm.py @@ -10,10 +10,12 @@ class LSTMUtilsTest(unittest.TestCase): def testBasic(self): eps_ids = [1, 1, 1, 5, 5, 5, 5, 5] + agent_ids = [1, 1, 1, 1, 1, 1, 1, 1] f = [[101, 102, 103, 201, 202, 203, 204, 205], [[101], [102], [103], [201], [202], [203], [204], [205]]] s = [[209, 208, 207, 109, 108, 107, 106, 105]] - f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, f, s, 4) + f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, agent_ids, f, s, + 4) self.assertEqual([f.tolist() for f in f_pad], [ [101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0], [[101], [102], [103], [0], [201], [202], [203], [204], [205], [0], @@ -22,11 +24,25 @@ def testBasic(self): self.assertEqual([s.tolist() for s in s_init], [[209, 
109, 105]]) self.assertEqual(seq_lens.tolist(), [3, 4, 1]) + def testMultiAgent(self): + eps_ids = [1, 1, 1, 5, 5, 5, 5, 5] + agent_ids = [1, 1, 2, 1, 1, 2, 2, 3] + f = [[101, 102, 103, 201, 202, 203, 204, 205], + [[101], [102], [103], [201], [202], [203], [204], [205]]] + s = [[209, 208, 207, 109, 108, 107, 106, 105]] + f_pad, s_init, seq_lens = chop_into_sequences( + eps_ids, agent_ids, f, s, 4, dynamic_max=False) + self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1]) + self.assertEqual(len(f_pad[0]), 20) + self.assertEqual(len(s_init[0]), 5) + def testDynamicMaxLen(self): eps_ids = [5, 2, 2] + agent_ids = [2, 2, 2] f = [[1, 1, 1]] s = [[1, 1, 1]] - f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, f, s, 4) + f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, agent_ids, f, s, + 4) self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]]) self.assertEqual([s.tolist() for s in s_init], [[1, 1]]) self.assertEqual(seq_lens.tolist(), [1, 2]) diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index 96eaabaf1dff2..1fdfa5d74ae8a 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -15,12 +15,19 @@ from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \ MockPolicyGraph from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.tune.registry import register_env +def one_hot(i, n): + out = [0.0] * n + out[i] = 1.0 + return out + + class BasicMultiAgent(MultiAgentEnv): """Env of N independent agents, each of which exits after 25 steps.""" @@ -63,7 +70,7 @@ def __init__(self, num, increment_obs=False): self.last_info = {} self.i = 0 self.num = num - self.observation_space = gym.spaces.Discrete(2) + self.observation_space = gym.spaces.Discrete(10) self.action_space = gym.spaces.Discrete(2) def reset(self): @@ -99,25 +106,32 @@ def step(self, action_dict): return obs, rew, done, info -class MultiCartpole(MultiAgentEnv): - def __init__(self, num): - self.agents = [gym.make("CartPole-v0") for _ in range(num)] - self.dones = set() - self.observation_space = self.agents[0].observation_space - self.action_space = self.agents[0].action_space +def make_multiagent(env_name): + class MultiEnv(MultiAgentEnv): + def __init__(self, num): + self.agents = [gym.make(env_name) for _ in range(num)] + self.dones = set() + self.observation_space = self.agents[0].observation_space + self.action_space = self.agents[0].action_space - def reset(self): - self.dones = set() - return {i: a.reset() for i, a in enumerate(self.agents)} + def reset(self): + self.dones = set() + return {i: a.reset() for i, a in enumerate(self.agents)} - def step(self, action_dict): - obs, rew, done, info = {}, {}, {}, {} - for i, action in action_dict.items(): - obs[i], rew[i], done[i], info[i] = self.agents[i].step(action) - if done[i]: - self.dones.add(i) - done["__all__"] = len(self.dones) == len(self.agents) - return obs, rew, done, info + def step(self, action_dict): + obs, rew, done, info = {}, {}, {}, {} + for i, action in action_dict.items(): + obs[i], rew[i], done[i], info[i] = self.agents[i].step(action) + if done[i]: + self.dones.add(i) + done["__all__"] = len(self.dones) == len(self.agents) + return obs, rew, done, info + + return MultiEnv + + 
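# A brief usage sketch for the make_multiagent factory above (assumed call
# pattern, mirroring the MultiEnv reset/step logic it returns): the wrapped env
# yields per-agent dicts keyed by agent index, plus the special "__all__" done flag.
#
#   env = make_multiagent("CartPole-v0")(2)
#   obs = env.reset()                              # {0: obs_0, 1: obs_1}
#   obs, rew, done, info = env.step({0: 0, 1: 1})
#   done                                           # e.g. {0: False, 1: False, "__all__": False}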
+MultiCartpole = make_multiagent("CartPole-v0") +MultiMountainCar = make_multiagent("MountainCarContinuous-v0") class TestMultiAgentEnv(unittest.TestCase): @@ -282,7 +296,7 @@ def testMultiAgentSampleWithHorizon(self): def testMultiAgentSampleRoundRobin(self): act_space = gym.spaces.Discrete(2) - obs_space = gym.spaces.Discrete(2) + obs_space = gym.spaces.Discrete(10) ev = PolicyEvaluator( env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True), policy_graph={ @@ -295,10 +309,20 @@ def testMultiAgentSampleRoundRobin(self): # since we round robin introduce agents into the env, some of the env # steps don't count as proper transitions self.assertEqual(batch.policy_batches["p0"].count, 42) - self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10], - [0, 1, 2, 3, 4] * 2) - self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10], - [1, 2, 3, 4, 5] * 2) + self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10], [ + one_hot(0, 10), + one_hot(1, 10), + one_hot(2, 10), + one_hot(3, 10), + one_hot(4, 10), + ] * 2) + self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10], [ + one_hot(1, 10), + one_hot(2, 10), + one_hot(3, 10), + one_hot(4, 10), + one_hot(5, 10), + ] * 2) self.assertEqual(batch.policy_batches["p0"]["rewards"].tolist()[:10], [100, 100, 100, 100, 0] * 2) self.assertEqual(batch.policy_batches["p0"]["dones"].tolist()[:10], @@ -306,12 +330,39 @@ def testMultiAgentSampleRoundRobin(self): self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10], [4, 9, 14, 19, 24, 5, 10, 15, 20, 25]) + def testCustomRNNStateValues(self): + h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}} + + class StatefulPolicyGraph(PolicyGraph): + def compute_actions(self, + obs_batch, + state_batches, + prev_action_batch=None, + prev_reward_batch=None, + episodes=None): + return [0] * len(obs_batch), [[h] * len(obs_batch)], {} + + def get_initial_state(self): + return [{}] # empty dict + + ev = PolicyEvaluator( + env_creator=lambda _: gym.make("CartPole-v0"), + policy_graph=StatefulPolicyGraph, + batch_steps=5) + batch = ev.sample() + self.assertEqual(batch.count, 5) + self.assertEqual(batch["state_in_0"][0], {}) + self.assertEqual(batch["state_out_0"][0], h) + self.assertEqual(batch["state_in_0"][1], h) + self.assertEqual(batch["state_out_0"][1], h) + def testReturningModelBasedRolloutsData(self): class ModelBasedPolicyGraph(PGPolicyGraph): def compute_actions(self, obs_batch, state_batches, - is_training=False, + prev_action_batch=None, + prev_reward_batch=None, episodes=None): # Pretend we did a model-based rollout and want to return # the extra trajectory. 
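# A hedged sketch of the extra-batch pattern this test exercises; only the two
# calls visible in the surrounding hunks are taken from the patch, and the
# builder is assumed to be a SampleBatchBuilder-like helper:
#
#   batch = builder.build_and_reset(episode=None)  # finalize the synthetic rollout
#   episodes[0].add_extra_batch(batch)             # hand it to the sampler, which
#                                                  # merges it into ev.sample()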
@@ -329,7 +380,7 @@ def compute_actions(self, dones=t == 4, infos={}, new_obs=obs_batch[0]) - batch = builder.build_and_reset() + batch = builder.build_and_reset(episode=None) episodes[0].add_extra_batch(batch) # Just return zeros for actions diff --git a/python/ray/rllib/test/test_nested_spaces.py b/python/ray/rllib/test/test_nested_spaces.py new file mode 100644 index 0000000000000..95744b7e278af --- /dev/null +++ b/python/ray/rllib/test/test_nested_spaces.py @@ -0,0 +1,346 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pickle + +from gym import spaces +from gym.envs.registration import EnvSpec +import gym +import tensorflow.contrib.slim as slim +import tensorflow as tf +import unittest + +import ray +from ray.rllib.agents.pg import PGAgent +from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.env import MultiAgentEnv +from ray.rllib.env.async_vector_env import AsyncVectorEnv +from ray.rllib.env.vector_env import VectorEnv +from ray.rllib.models import ModelCatalog +from ray.rllib.models.model import Model +from ray.rllib.test.test_external_env import SimpleServing +from ray.tune.registry import register_env + +DICT_SPACE = spaces.Dict({ + "sensors": spaces.Dict({ + "position": spaces.Box(low=-100, high=100, shape=(3, )), + "velocity": spaces.Box(low=-1, high=1, shape=(3, )), + "front_cam": spaces.Tuple( + (spaces.Box(low=0, high=1, shape=(10, 10, 3)), + spaces.Box(low=0, high=1, shape=(10, 10, 3)))), + "rear_cam": spaces.Box(low=0, high=1, shape=(10, 10, 3)), + }), + "inner_state": spaces.Dict({ + "charge": spaces.Discrete(100), + "job_status": spaces.Dict({ + "task": spaces.Discrete(5), + "progress": spaces.Box(low=0, high=100, shape=()), + }) + }) +}) + +DICT_SAMPLES = [DICT_SPACE.sample() for _ in range(10)] + +TUPLE_SPACE = spaces.Tuple([ + spaces.Box(low=-100, high=100, shape=(3, )), + spaces.Tuple((spaces.Box(low=0, high=1, shape=(10, 10, 3)), + spaces.Box(low=0, high=1, shape=(10, 10, 3)))), + spaces.Discrete(5), +]) + +TUPLE_SAMPLES = [TUPLE_SPACE.sample() for _ in range(10)] + + +def one_hot(i, n): + out = [0.0] * n + out[i] = 1.0 + return out + + +class NestedDictEnv(gym.Env): + def __init__(self): + self.action_space = spaces.Discrete(2) + self.observation_space = DICT_SPACE + self._spec = EnvSpec("NestedDictEnv-v0") + self.steps = 0 + + def reset(self): + self.steps = 0 + return DICT_SAMPLES[0] + + def step(self, action): + self.steps += 1 + return DICT_SAMPLES[self.steps], 1, self.steps >= 5, {} + + +class NestedTupleEnv(gym.Env): + def __init__(self): + self.action_space = spaces.Discrete(2) + self.observation_space = TUPLE_SPACE + self._spec = EnvSpec("NestedTupleEnv-v0") + self.steps = 0 + + def reset(self): + self.steps = 0 + return TUPLE_SAMPLES[0] + + def step(self, action): + self.steps += 1 + return TUPLE_SAMPLES[self.steps], 1, self.steps >= 5, {} + + +class NestedMultiAgentEnv(MultiAgentEnv): + def __init__(self): + self.steps = 0 + + def reset(self): + return { + "dict_agent": DICT_SAMPLES[0], + "tuple_agent": TUPLE_SAMPLES[0], + } + + def step(self, actions): + self.steps += 1 + obs = { + "dict_agent": DICT_SAMPLES[self.steps], + "tuple_agent": TUPLE_SAMPLES[self.steps], + } + rew = { + "dict_agent": 0, + "tuple_agent": 0, + } + dones = {"__all__": self.steps >= 5} + infos = { + "dict_agent": {}, + "tuple_agent": {}, + } + return obs, rew, dones, infos + + +class InvalidModel(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + return 
"not", "valid" + + +class InvalidModel2(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + return tf.constant(0), tf.constant(0) + + +class DictSpyModel(Model): + capture_index = 0 + + def _build_layers_v2(self, input_dict, num_outputs, options): + def spy(pos, front_cam, task): + # TF runs this function in an isolated context, so we have to use + # redis to communicate back to our suite + ray.experimental.internal_kv._internal_kv_put( + "d_spy_in_{}".format(DictSpyModel.capture_index), + pickle.dumps((pos, front_cam, task)), + overwrite=True) + DictSpyModel.capture_index += 1 + return 0 + + spy_fn = tf.py_func( + spy, [ + input_dict["obs"]["sensors"]["position"], + input_dict["obs"]["sensors"]["front_cam"][0], + input_dict["obs"]["inner_state"]["job_status"]["task"] + ], + tf.int64, + stateful=True) + + with tf.control_dependencies([spy_fn]): + output = slim.fully_connected( + input_dict["obs"]["sensors"]["position"], num_outputs) + return output, output + + +class TupleSpyModel(Model): + capture_index = 0 + + def _build_layers_v2(self, input_dict, num_outputs, options): + def spy(pos, cam, task): + # TF runs this function in an isolated context, so we have to use + # redis to communicate back to our suite + ray.experimental.internal_kv._internal_kv_put( + "t_spy_in_{}".format(TupleSpyModel.capture_index), + pickle.dumps((pos, cam, task)), + overwrite=True) + TupleSpyModel.capture_index += 1 + return 0 + + spy_fn = tf.py_func( + spy, [ + input_dict["obs"][0], + input_dict["obs"][1][0], + input_dict["obs"][2], + ], + tf.int64, + stateful=True) + + with tf.control_dependencies([spy_fn]): + output = slim.fully_connected(input_dict["obs"][0], num_outputs) + return output, output + + +class NestedSpacesTest(unittest.TestCase): + def testInvalidModel(self): + ModelCatalog.register_custom_model("invalid", InvalidModel) + self.assertRaises(ValueError, lambda: PGAgent( + env="CartPole-v0", config={ + "model": { + "custom_model": "invalid", + }, + })) + + def testInvalidModel2(self): + ModelCatalog.register_custom_model("invalid2", InvalidModel2) + self.assertRaisesRegexp( + ValueError, "Expected output.*", + lambda: PGAgent( + env="CartPole-v0", config={ + "model": { + "custom_model": "invalid2", + }, + })) + + def doTestNestedDict(self, make_env, test_lstm=False): + ModelCatalog.register_custom_model("composite", DictSpyModel) + register_env("nested", make_env) + pg = PGAgent( + env="nested", + config={ + "num_workers": 0, + "sample_batch_size": 5, + "model": { + "custom_model": "composite", + "use_lstm": test_lstm, + }, + }) + pg.train() + + # Check that the model sees the correct reconstructed observations + for i in range(4): + seen = pickle.loads( + ray.experimental.internal_kv._internal_kv_get( + "d_spy_in_{}".format(i))) + pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist() + cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist() + task_i = one_hot( + DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5) + self.assertEqual(seen[0][0].tolist(), pos_i) + self.assertEqual(seen[1][0].tolist(), cam_i) + self.assertEqual(seen[2][0].tolist(), task_i) + + def doTestNestedTuple(self, make_env): + ModelCatalog.register_custom_model("composite2", TupleSpyModel) + register_env("nested2", make_env) + pg = PGAgent( + env="nested2", + config={ + "num_workers": 0, + "sample_batch_size": 5, + "model": { + "custom_model": "composite2", + }, + }) + pg.train() + + # Check that the model sees the correct reconstructed observations + for i in range(4): + seen = pickle.loads( + 
ray.experimental.internal_kv._internal_kv_get( + "t_spy_in_{}".format(i))) + pos_i = TUPLE_SAMPLES[i][0].tolist() + cam_i = TUPLE_SAMPLES[i][1][0].tolist() + task_i = one_hot(TUPLE_SAMPLES[i][2], 5) + self.assertEqual(seen[0][0].tolist(), pos_i) + self.assertEqual(seen[1][0].tolist(), cam_i) + self.assertEqual(seen[2][0].tolist(), task_i) + + def testNestedDictGym(self): + self.doTestNestedDict(lambda _: NestedDictEnv()) + + def testNestedDictGymLSTM(self): + self.doTestNestedDict(lambda _: NestedDictEnv(), test_lstm=True) + + def testNestedDictVector(self): + self.doTestNestedDict( + lambda _: VectorEnv.wrap(lambda i: NestedDictEnv())) + + def testNestedDictServing(self): + self.doTestNestedDict(lambda _: SimpleServing(NestedDictEnv())) + + def testNestedDictAsync(self): + self.doTestNestedDict( + lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv())) + + def testNestedTupleGym(self): + self.doTestNestedTuple(lambda _: NestedTupleEnv()) + + def testNestedTupleVector(self): + self.doTestNestedTuple( + lambda _: VectorEnv.wrap(lambda i: NestedTupleEnv())) + + def testNestedTupleServing(self): + self.doTestNestedTuple(lambda _: SimpleServing(NestedTupleEnv())) + + def testNestedTupleAsync(self): + self.doTestNestedTuple( + lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv())) + + def testMultiAgentComplexSpaces(self): + ModelCatalog.register_custom_model("dict_spy", DictSpyModel) + ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel) + register_env("nested_ma", lambda _: NestedMultiAgentEnv()) + act_space = spaces.Discrete(2) + pg = PGAgent( + env="nested_ma", + config={ + "num_workers": 0, + "sample_batch_size": 5, + "multiagent": { + "policy_graphs": { + "tuple_policy": ( + PGPolicyGraph, TUPLE_SPACE, act_space, + {"model": {"custom_model": "tuple_spy"}}), + "dict_policy": ( + PGPolicyGraph, DICT_SPACE, act_space, + {"model": {"custom_model": "dict_spy"}}), + }, + "policy_mapping_fn": lambda a: { + "tuple_agent": "tuple_policy", + "dict_agent": "dict_policy"}[a], + }, + }) + pg.train() + + for i in range(4): + seen = pickle.loads( + ray.experimental.internal_kv._internal_kv_get( + "d_spy_in_{}".format(i))) + pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist() + cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist() + task_i = one_hot( + DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5) + self.assertEqual(seen[0][0].tolist(), pos_i) + self.assertEqual(seen[1][0].tolist(), cam_i) + self.assertEqual(seen[2][0].tolist(), task_i) + + for i in range(4): + seen = pickle.loads( + ray.experimental.internal_kv._internal_kv_get( + "t_spy_in_{}".format(i))) + pos_i = TUPLE_SAMPLES[i][0].tolist() + cam_i = TUPLE_SAMPLES[i][1][0].tolist() + task_i = one_hot(TUPLE_SAMPLES[i][2], 5) + self.assertEqual(seen[0][0].tolist(), pos_i) + self.assertEqual(seen[1][0].tolist(), cam_i) + self.assertEqual(seen[2][0].tolist(), task_i) + + +if __name__ == "__main__": + ray.init(num_cpus=5) + unittest.main(verbosity=2) diff --git a/python/ray/rllib/test/test_policy_evaluator.py b/python/ray/rllib/test/test_policy_evaluator.py index cc189edbf6e0c..cf319a7e922b2 100644 --- a/python/ray/rllib/test/test_policy_evaluator.py +++ b/python/ray/rllib/test/test_policy_evaluator.py @@ -3,8 +3,10 @@ from __future__ import print_function import gym +import numpy as np import time import unittest +from collections import Counter import ray from ray.rllib.agents.pg import PGAgent @@ -21,11 +23,16 @@ class MockPolicyGraph(PolicyGraph): def compute_actions(self, obs_batch, state_batches, - is_training=False, + 
prev_action_batch=None, + prev_reward_batch=None, episodes=None): return [0] * len(obs_batch), [], {} - def postprocess_trajectory(self, batch, other_agent_batches=None): + def postprocess_trajectory(self, + batch, + other_agent_batches=None, + episode=None): + assert episode is not None return compute_advantages(batch, 100.0, 0.9, use_gae=False) @@ -33,14 +40,31 @@ class BadPolicyGraph(PolicyGraph): def compute_actions(self, obs_batch, state_batches, - is_training=False, + prev_action_batch=None, + prev_reward_batch=None, episodes=None): raise Exception("intentional error") - def postprocess_trajectory(self, batch, other_agent_batches=None): + def postprocess_trajectory(self, + batch, + other_agent_batches=None, + episode=None): + assert episode is not None return compute_advantages(batch, 100.0, 0.9, use_gae=False) +class FailOnStepEnv(gym.Env): + def __init__(self): + self.observation_space = gym.spaces.Discrete(1) + self.action_space = gym.spaces.Discrete(2) + + def reset(self): + raise ValueError("kaboom") + + def step(self, action): + raise ValueError("kaboom") + + class MockEnv(gym.Env): def __init__(self, episode_length, config=None): self.episode_length = episode_length @@ -107,10 +131,39 @@ def testBasic(self): env_creator=lambda _: gym.make("CartPole-v0"), policy_graph=MockPolicyGraph) batch = ev.sample() - for key in ["obs", "actions", "rewards", "dones", "advantages"]: + for key in [ + "obs", "actions", "rewards", "dones", "advantages", + "prev_rewards", "prev_actions" + ]: self.assertIn(key, batch) + + def to_prev(vec): + out = np.zeros_like(vec) + for i, v in enumerate(vec): + if i + 1 < len(out) and not batch["dones"][i]: + out[i + 1] = v + return out.tolist() + + self.assertEqual(batch["prev_rewards"].tolist(), + to_prev(batch["rewards"])) + self.assertEqual(batch["prev_actions"].tolist(), + to_prev(batch["actions"])) self.assertGreater(batch["advantages"][0], 1) + # 11/23/18: Samples per second 8501.125113727468 + def testBaselinePerformance(self): + ev = PolicyEvaluator( + env_creator=lambda _: gym.make("CartPole-v0"), + policy_graph=MockPolicyGraph, + batch_steps=100) + start = time.time() + count = 0 + while time.time() - start < 1: + count += ev.sample().count + print() + print("Samples per second {}".format(count / (time.time() - start))) + print() + def testGlobalVarsUpdate(self): agent = A2CAgent( env="CartPole-v0", @@ -122,6 +175,34 @@ def testGlobalVarsUpdate(self): result2 = agent.train() self.assertLess(result2["info"]["learner"]["cur_lr"], 0.0001) + def testNoStepOnInit(self): + register_env("fail", lambda _: FailOnStepEnv()) + pg = PGAgent(env="fail", config={"num_workers": 1}) + self.assertRaises(Exception, lambda: pg.train()) + + def testCallbacks(self): + counts = Counter() + pg = PGAgent( + env="CartPole-v0", config={ + "num_workers": 0, + "sample_batch_size": 50, + "callbacks": { + "on_episode_start": lambda x: counts.update({"start": 1}), + "on_episode_step": lambda x: counts.update({"step": 1}), + "on_episode_end": lambda x: counts.update({"end": 1}), + "on_sample_end": lambda x: counts.update({"sample": 1}), + }, + }) + pg.train() + pg.train() + pg.train() + pg.train() + self.assertEqual(counts["sample"], 4) + self.assertGreater(counts["start"], 0) + self.assertGreater(counts["end"], 0) + self.assertGreater(counts["step"], 200) + self.assertLess(counts["step"], 400) + def testQueryEvaluators(self): register_env("test", lambda _: gym.make("CartPole-v0")) pg = PGAgent( @@ -129,9 +210,10 @@ def testQueryEvaluators(self): "num_workers": 2, 
"sample_batch_size": 5 }) - results = pg.optimizer.foreach_evaluator(lambda ev: ev.batch_steps) + results = pg.optimizer.foreach_evaluator( + lambda ev: ev.sample_batch_size) results2 = pg.optimizer.foreach_evaluator_with_index( - lambda ev, i: (i, ev.batch_steps)) + lambda ev, i: (i, ev.sample_batch_size)) self.assertEqual(results, [5, 5, 5]) self.assertEqual(results2, [(0, 5), (1, 5), (2, 5)]) @@ -198,7 +280,7 @@ def testAutoVectorization(self): env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg), policy_graph=MockPolicyGraph, batch_mode="truncate_episodes", - batch_steps=16, + batch_steps=2, num_envs=8) for _ in range(8): batch = ev.sample() @@ -216,21 +298,12 @@ def testAutoVectorization(self): indices.append(env.unwrapped.config.vector_index) self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7]) - def testBatchDivisibilityCheck(self): - self.assertRaises( - ValueError, - lambda: PolicyEvaluator( - env_creator=lambda _: MockEnv(episode_length=8), - policy_graph=MockPolicyGraph, - batch_mode="truncate_episodes", - batch_steps=15, num_envs=4)) - - def testBatchesSmallerWhenVectorized(self): + def testBatchesLargerWhenVectorized(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(episode_length=8), policy_graph=MockPolicyGraph, batch_mode="truncate_episodes", - batch_steps=16, + batch_steps=4, num_envs=4) batch = ev.sample() self.assertEqual(batch.count, 16) diff --git a/python/ray/rllib/test/test_rollout.sh b/python/ray/rllib/test/test_rollout.sh new file mode 100755 index 0000000000000..04685b2be345d --- /dev/null +++ b/python/ray/rllib/test/test_rollout.sh @@ -0,0 +1,28 @@ +#!/bin/bash -e + +TRAIN=/ray/python/ray/rllib/train.py +if [ ! -e "$TRAIN" ]; then + TRAIN=../train.py +fi +ROLLOUT=/ray/python/ray/rllib/rollout.py +if [ ! -e "$ROLLOUT" ]; then + ROLLOUT=../rollout.py +fi + +TMP=`mktemp -d` +echo "Saving results to $TMP" + +$TRAIN --local-dir=$TMP --run=IMPALA --checkpoint-freq=1 \ + --config='{"num_workers": 1, "num_gpus": 0}' --env=Pong-ram-v4 \ + --stop='{"training_iteration": 1}' +find $TMP + +CHECKPOINT_PATH=`ls $TMP/default/*/checkpoint_1/checkpoint-1` +echo "Checkpoint path $CHECKPOINT_PATH" +test -e "$CHECKPOINT_PATH" + +$ROLLOUT --run=IMPALA "$CHECKPOINT_PATH" --steps=100 \ + --out="$TMP/rollouts.pkl" --no-render +test -e "$TMP/rollouts.pkl" +rm -rf "$TMP" +echo "OK" diff --git a/python/ray/rllib/test/test_supported_spaces.py b/python/ray/rllib/test/test_supported_spaces.py index 2ced3402a78a6..2e5b74b536f90 100644 --- a/python/ray/rllib/test/test_supported_spaces.py +++ b/python/ray/rllib/test/test_supported_spaces.py @@ -2,49 +2,41 @@ import traceback import gym -from gym.spaces import Box, Discrete, Tuple +from gym.spaces import Box, Discrete, Tuple, Dict from gym.envs.registration import EnvSpec import numpy as np import sys import ray from ray.rllib.agents.agent import get_agent_class +from ray.rllib.test.test_multi_agent_env import MultiCartpole, MultiMountainCar from ray.rllib.utils.error import UnsupportedSpaceException from ray.tune.registry import register_env ACTION_SPACES_TO_TEST = { "discrete": Discrete(5), - "vector": Box(0.0, 1.0, (5, ), dtype=np.float32), - "simple_tuple": Tuple([ - Box(0.0, 1.0, (5, ), dtype=np.float32), - Box(0.0, 1.0, (5, ), dtype=np.float32) - ]), - "implicit_tuple": [ - Box(0.0, 1.0, (5, ), dtype=np.float32), - Box(0.0, 1.0, (5, ), dtype=np.float32) - ], - "mixed_tuple": Tuple( + "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32), + "tuple": Tuple( [Discrete(2), Discrete(3), - Box(0.0, 1.0, (5, ), dtype=np.float32)]), + 
Box(-1.0, 1.0, (5, ), dtype=np.float32)]), } OBSERVATION_SPACES_TO_TEST = { "discrete": Discrete(5), - "vector": Box(0.0, 1.0, (5, ), dtype=np.float32), - "image": Box(0.0, 1.0, (84, 84, 1), dtype=np.float32), - "atari": Box(0.0, 1.0, (210, 160, 3), dtype=np.float32), - "atari_ram": Box(0.0, 1.0, (128, ), dtype=np.float32), - "simple_tuple": Tuple([ - Box(0.0, 1.0, (5, ), dtype=np.float32), - Box(0.0, 1.0, (5, ), dtype=np.float32) - ]), - "mixed_tuple": Tuple( - [Discrete(10), Box(0.0, 1.0, (5, ), dtype=np.float32)]), + "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32), + "image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32), + "atari": Box(-1.0, 1.0, (210, 160, 3), dtype=np.float32), + "tuple": Tuple([Discrete(10), + Box(-1.0, 1.0, (5, ), dtype=np.float32)]), + "dict": Dict({ + "task": Discrete(10), + "position": Box(-1.0, 1.0, (5, ), dtype=np.float32), + }), } -def make_stub_env(action_space, obs_space): +def make_stub_env(action_space, obs_space, check_action_bounds): class StubEnv(gym.Env): def __init__(self): self.action_space = action_space @@ -56,23 +48,30 @@ def reset(self): return sample def step(self, action): + if check_action_bounds and not self.action_space.contains(action): + raise ValueError("Illegal action for {}: {}".format( + self.action_space, action)) + if (isinstance(self.action_space, Tuple) + and len(action) != len(self.action_space.spaces)): + raise ValueError("Illegal action for {}: {}".format( + self.action_space, action)) return self.observation_space.sample(), 1, True, {} return StubEnv -def check_support(alg, config, stats): +def check_support(alg, config, stats, check_bounds=False): for a_name, action_space in ACTION_SPACES_TO_TEST.items(): for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items(): print("=== Testing", alg, action_space, obs_space, "===") - stub_env = make_stub_env(action_space, obs_space) + stub_env = make_stub_env(action_space, obs_space, check_bounds) register_env("stub_env", lambda c: stub_env()) stat = "ok" a = None try: a = get_agent_class(alg)(config=config, env="stub_env") a.train() - except UnsupportedSpaceException as e: + except UnsupportedSpaceException: stat = "unsupported" except Exception as e: stat = "ERROR" @@ -90,27 +89,56 @@ def check_support(alg, config, stats): stats[alg, a_name, o_name] = stat +def check_support_multiagent(alg, config): + register_env("multi_mountaincar", lambda _: MultiMountainCar(2)) + register_env("multi_cartpole", lambda _: MultiCartpole(2)) + if alg == "DDPG": + a = get_agent_class(alg)(config=config, env="multi_mountaincar") + else: + a = get_agent_class(alg)(config=config, env="multi_cartpole") + try: + a.train() + finally: + a.stop() + + class ModelSupportedSpaces(unittest.TestCase): + def setUp(self): + ray.init(num_cpus=4) + + def tearDown(self): + ray.shutdown() + def testAll(self): - ray.init() stats = {} - check_support("IMPALA", {"gpu": False}, stats) - check_support("DDPG", {"timesteps_per_iteration": 1}, stats) + check_support("IMPALA", {"num_gpus": 0}, stats) + check_support( + "DDPG", { + "noise_scale": 100.0, + "timesteps_per_iteration": 1 + }, + stats, + check_bounds=True) check_support("DQN", {"timesteps_per_iteration": 1}, stats) - check_support("A3C", { - "num_workers": 1, - "optimizer": { - "grads_per_step": 1 - } - }, stats) + check_support( + "A3C", { + "num_workers": 1, + "optimizer": { + "grads_per_step": 1 + } + }, + stats, + check_bounds=True) check_support( "PPO", { "num_workers": 1, "num_sgd_iter": 1, "train_batch_size": 10, "sample_batch_size": 10, - 
"sgd_minibatch_size": 1 - }, stats) + "sgd_minibatch_size": 1, + }, + stats, + check_bounds=True) check_support( "ES", { "num_workers": 1, @@ -125,7 +153,13 @@ def testAll(self): "num_rollouts": 1, "rollouts_used": 1 }, stats) - check_support("PG", {"num_workers": 1, "optimizer": {}}, stats) + check_support( + "PG", { + "num_workers": 1, + "optimizer": {} + }, + stats, + check_bounds=True) num_unexpected_errors = 0 for (alg, a_name, o_name), stat in sorted(stats.items()): if stat not in ["ok", "unsupported"]: @@ -134,6 +168,26 @@ def testAll(self): stat) self.assertEqual(num_unexpected_errors, 0) + def testMultiAgent(self): + check_support_multiagent("IMPALA", {"num_gpus": 0}) + check_support_multiagent("DQN", {"timesteps_per_iteration": 1}) + check_support_multiagent("A3C", { + "num_workers": 1, + "optimizer": { + "grads_per_step": 1 + } + }) + check_support_multiagent( + "PPO", { + "num_workers": 1, + "num_sgd_iter": 1, + "train_batch_size": 10, + "sample_batch_size": 10, + "sgd_minibatch_size": 1, + }) + check_support_multiagent("PG", {"num_workers": 1, "optimizer": {}}) + check_support_multiagent("DDPG", {"timesteps_per_iteration": 1}) + if __name__ == "__main__": if len(sys.argv) > 1 and sys.argv[1] == "--smoke": diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py index b9e7b72efd673..72d6fc0b58c3d 100755 --- a/python/ray/rllib/train.py +++ b/python/ray/rllib/train.py @@ -8,6 +8,7 @@ import yaml import ray +from ray.test.cluster_utils import Cluster from ray.tune.config_parser import make_parser, resources_to_json from ray.tune.tune import _make_scheduler, run_experiments @@ -37,19 +38,33 @@ def create_parser(parser_creator=None): "--redis-address", default=None, type=str, - help="The Redis address of the cluster.") + help="Connect to an existing Ray cluster at this address instead " + "of starting a new one.") parser.add_argument( "--ray-num-cpus", default=None, type=int, - help="--num-cpus to pass to Ray." - " This only has an affect in local mode.") + help="--num-cpus to use if starting a new cluster.") parser.add_argument( "--ray-num-gpus", default=None, type=int, - help="--num-gpus to pass to Ray." 
- " This only has an affect in local mode.") + help="--num-gpus to use if starting a new cluster.") + parser.add_argument( + "--ray-num-local-schedulers", + default=None, + type=int, + help="Emulate multiple cluster nodes for debugging.") + parser.add_argument( + "--ray-redis-max-memory", + default=None, + type=int, + help="--redis-max-memory to use if starting a new cluster.") + parser.add_argument( + "--ray-object-store-memory", + default=None, + type=int, + help="--object-store-memory to use if starting a new cluster.") parser.add_argument( "--experiment-name", default="default", @@ -102,10 +117,24 @@ def run(args, parser): if not exp.get("env") and not exp.get("config", {}).get("env"): parser.error("the following arguments are required: --env") - ray.init( - redis_address=args.redis_address, - num_cpus=args.ray_num_cpus, - num_gpus=args.ray_num_gpus) + if args.ray_num_local_schedulers: + cluster = Cluster() + for _ in range(args.ray_num_local_schedulers): + cluster.add_node( + resources={ + "num_cpus": args.ray_num_cpus or 1, + "num_gpus": args.ray_num_gpus or 0, + }, + object_store_memory=args.ray_object_store_memory, + redis_max_memory=args.ray_redis_max_memory) + ray.init(redis_address=cluster.redis_address) + else: + ray.init( + redis_address=args.redis_address, + object_store_memory=args.ray_object_store_memory, + redis_max_memory=args.ray_redis_max_memory, + num_cpus=args.ray_num_cpus, + num_gpus=args.ray_num_gpus) run_experiments( experiments, scheduler=_make_scheduler(args), diff --git a/python/ray/rllib/tuned_examples/atari-a2c.yaml b/python/ray/rllib/tuned_examples/atari-a2c.yaml index 89feaee5ba8b6..42ea119638e62 100644 --- a/python/ray/rllib/tuned_examples/atari-a2c.yaml +++ b/python/ray/rllib/tuned_examples/atari-a2c.yaml @@ -9,11 +9,11 @@ atari-a2c: - SpaceInvadersNoFrameskip-v4 run: A2C config: - sample_batch_size: 100 + sample_batch_size: 20 clip_rewards: True num_workers: 5 num_envs_per_worker: 5 - gpu: true + num_gpus: 1 lr_schedule: [ [0, 0.0007], [20000000, 0.000000000001], diff --git a/python/ray/rllib/tuned_examples/atari-apex.yaml b/python/ray/rllib/tuned_examples/atari-apex.yaml index 6e538d038998a..e24e347dd18aa 100644 --- a/python/ray/rllib/tuned_examples/atari-apex.yaml +++ b/python/ray/rllib/tuned_examples/atari-apex.yaml @@ -23,12 +23,12 @@ apex: prioritized_replay_alpha: 0.5 beta_annealing_fraction: 1.0 final_prioritized_replay_beta: 1.0 - gpu: false + num_gpus: 1 # APEX num_workers: 8 num_envs_per_worker: 8 - sample_batch_size: 158 + sample_batch_size: 20 train_batch_size: 512 target_network_update_freq: 50000 timesteps_per_iteration: 25000 diff --git a/python/ray/rllib/tuned_examples/atari-dist-dqn.yaml b/python/ray/rllib/tuned_examples/atari-dist-dqn.yaml index d719329861e88..d351e403f2e23 100644 --- a/python/ray/rllib/tuned_examples/atari-dist-dqn.yaml +++ b/python/ray/rllib/tuned_examples/atari-dist-dqn.yaml @@ -27,5 +27,5 @@ basic-dqn: prioritized_replay_alpha: 0.5 beta_annealing_fraction: 1.0 final_prioritized_replay_beta: 1.0 - gpu: true + num_gpus: 0.2 timesteps_per_iteration: 10000 diff --git a/python/ray/rllib/tuned_examples/atari-dqn.yaml b/python/ray/rllib/tuned_examples/atari-dqn.yaml index 4929017879c97..b8731bb054ef3 100644 --- a/python/ray/rllib/tuned_examples/atari-dqn.yaml +++ b/python/ray/rllib/tuned_examples/atari-dqn.yaml @@ -1,4 +1,4 @@ -# Runs on a single g3.16xl node +# Runs on a single g3.4xl node # See https://github.com/ray-project/rl-experiments for results atari-basic-dqn: env: @@ -29,5 +29,5 @@ atari-basic-dqn: 
prioritized_replay_alpha: 0.5 beta_annealing_fraction: 1.0 final_prioritized_replay_beta: 1.0 - gpu: true + num_gpus: 0.2 timesteps_per_iteration: 10000 diff --git a/python/ray/rllib/tuned_examples/atari-duel-ddqn.yaml b/python/ray/rllib/tuned_examples/atari-duel-ddqn.yaml index 61ed3120de1ac..b5a13162b61e4 100644 --- a/python/ray/rllib/tuned_examples/atari-duel-ddqn.yaml +++ b/python/ray/rllib/tuned_examples/atari-duel-ddqn.yaml @@ -1,3 +1,5 @@ +# Runs on a single g3.4xl node +# See https://github.com/ray-project/rl-experiments for results dueling-ddqn: env: grid_search: @@ -27,5 +29,5 @@ dueling-ddqn: prioritized_replay_alpha: 0.5 beta_annealing_fraction: 1.0 final_prioritized_replay_beta: 1.0 - gpu: true + num_gpus: 0.2 timesteps_per_iteration: 10000 diff --git a/python/ray/rllib/tuned_examples/atari-impala.yaml b/python/ray/rllib/tuned_examples/atari-impala.yaml index 85bd801ff83b5..597b41987b3f6 100644 --- a/python/ray/rllib/tuned_examples/atari-impala.yaml +++ b/python/ray/rllib/tuned_examples/atari-impala.yaml @@ -9,7 +9,7 @@ atari-impala: - SpaceInvadersNoFrameskip-v4 run: IMPALA config: - sample_batch_size: 250 # 50 * num_envs_per_worker + sample_batch_size: 50 train_batch_size: 500 num_workers: 32 num_envs_per_worker: 5 diff --git a/python/ray/rllib/tuned_examples/atari-ppo.yaml b/python/ray/rllib/tuned_examples/atari-ppo.yaml index 24593d6bb9299..c6be6435041ca 100644 --- a/python/ray/rllib/tuned_examples/atari-ppo.yaml +++ b/python/ray/rllib/tuned_examples/atari-ppo.yaml @@ -16,7 +16,7 @@ atari-ppo: vf_clip_param: 10.0 entropy_coeff: 0.01 train_batch_size: 5000 - sample_batch_size: 500 + sample_batch_size: 100 sgd_minibatch_size: 500 num_sgd_iter: 10 num_workers: 10 diff --git a/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml b/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml index 34c60e5219b4e..f02399ab33ff2 100644 --- a/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml +++ b/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml @@ -34,8 +34,9 @@ halfcheetah-ddpg: clip_rewards: False # === Optimization === - actor_lr: 0.0001 - critic_lr: 0.001 + lr: 0.001 + actor_loss_coeff: 0.1 + critic_loss_coeff: 1.0 use_huber: False huber_threshold: 1.0 l2_reg: 0.000001 diff --git a/python/ray/rllib/tuned_examples/hopper-ppo.yaml b/python/ray/rllib/tuned_examples/hopper-ppo.yaml index c1c75b166e7cd..5082dc7921e47 100644 --- a/python/ray/rllib/tuned_examples/hopper-ppo.yaml +++ b/python/ray/rllib/tuned_examples/hopper-ppo.yaml @@ -10,3 +10,4 @@ hopper-ppo: train_batch_size: 160000 num_workers: 64 num_gpus: 4 + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml b/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml index e176dcae26c67..9473b5df7a6a3 100644 --- a/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml +++ b/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml @@ -17,3 +17,4 @@ humanoid-ppo-gae: free_log_std: true num_workers: 64 num_gpus: 4 + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/humanoid-ppo.yaml b/python/ray/rllib/tuned_examples/humanoid-ppo.yaml index 0608f8b60353c..07371d16f712b 100644 --- a/python/ray/rllib/tuned_examples/humanoid-ppo.yaml +++ b/python/ray/rllib/tuned_examples/humanoid-ppo.yaml @@ -15,3 +15,4 @@ humanoid-ppo: use_gae: false num_workers: 64 num_gpus: 4 + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml b/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml index a71b1e98ff658..e74b2e0f138e8 
100644 --- a/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml +++ b/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml @@ -34,8 +34,9 @@ mountaincarcontinuous-ddpg: clip_rewards: False # === Optimization === - actor_lr: 0.0001 - critic_lr: 0.001 + lr: 0.001 + actor_loss_coeff: 0.1 + critic_loss_coeff: 1.0 use_huber: False huber_threshold: 1.0 l2_reg: 0.00001 diff --git a/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml b/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml index 3cf68bcdc23e4..e28eee3e8e216 100644 --- a/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml +++ b/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml @@ -34,8 +34,9 @@ pendulum-ddpg: clip_rewards: False # === Optimization === - actor_lr: 0.0001 - critic_lr: 0.001 + lr: 0.001 + actor_loss_coeff: 0.1 + critic_loss_coeff: 1.0 use_huber: True huber_threshold: 1.0 l2_reg: 0.000001 diff --git a/python/ray/rllib/tuned_examples/pendulum-ppo.yaml b/python/ray/rllib/tuned_examples/pendulum-ppo.yaml index 60df6825bd435..b8c0293a3e338 100644 --- a/python/ray/rllib/tuned_examples/pendulum-ppo.yaml +++ b/python/ray/rllib/tuned_examples/pendulum-ppo.yaml @@ -13,4 +13,4 @@ pendulum-ppo: num_sgd_iter: 10 model: fcnet_hiddens: [64, 64] - squash_to_range: True + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/pendulum-td3.yaml b/python/ray/rllib/tuned_examples/pendulum-td3.yaml new file mode 100644 index 0000000000000..25b0900d63c2a --- /dev/null +++ b/python/ray/rllib/tuned_examples/pendulum-td3.yaml @@ -0,0 +1,60 @@ +# This configuration can expect to reach -160 reward in 10k-20k timesteps +pendulum-ddpg: + env: Pendulum-v0 + run: DDPG + stop: + episode_reward_mean: -160 + time_total_s: 600 # 10 minutes + config: + # === Tricks === + twin_q: True + policy_delay: 2 + smooth_target_policy: True + act_noise: 0.1 + target_noise: 0.2 + noise_clip: 0.5 + + # === Model === + actor_hiddens: [64, 64] + critic_hiddens: [64, 64] + n_step: 1 + model: {} + gamma: 0.99 + env_config: {} + + # === Exploration === + schedule_max_timesteps: 100000 + timesteps_per_iteration: 600 + exploration_fraction: 0.1 + exploration_final_eps: 0.02 + noise_scale: 0.1 + exploration_theta: 0.15 + exploration_sigma: 0.2 + target_network_update_freq: 0 + tau: 0.001 + + # === Replay buffer === + buffer_size: 10000 + prioritized_replay: True + prioritized_replay_alpha: 0.6 + prioritized_replay_beta: 0.4 + prioritized_replay_eps: 0.000001 + clip_rewards: False + + # === Optimization === + lr: 0.001 + actor_loss_coeff: 0.1 + critic_loss_coeff: 1.0 + use_huber: True + huber_threshold: 1.0 + l2_reg: 0.000001 + learning_starts: 500 + sample_batch_size: 1 + train_batch_size: 64 + + # === Parallelism === + num_workers: 0 + num_gpus_per_worker: 0 + optimizer_class: "SyncReplayOptimizer" + per_worker_exploration: False + worker_side_prioritization: False diff --git a/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml b/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml index 891c4b9919277..c3f608ddccb66 100644 --- a/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml +++ b/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml @@ -15,7 +15,7 @@ pong-a3c-pytorch-cnn: model: use_lstm: false channel_major: true - dim: 80 + dim: 84 grayscale: true zero_mean: false optimizer: diff --git a/python/ray/rllib/tuned_examples/pong-dqn.yaml b/python/ray/rllib/tuned_examples/pong-dqn.yaml index a0d39cc3dadc2..2c3e5a877ed4c 100644 --- a/python/ray/rllib/tuned_examples/pong-dqn.yaml +++ 
b/python/ray/rllib/tuned_examples/pong-dqn.yaml @@ -6,7 +6,7 @@ pong-deterministic-dqn: episode_reward_mean: 20 time_total_s: 7200 config: - gpu: True + num_gpus: 1 gamma: 0.99 lr: .0001 learning_starts: 10000 diff --git a/python/ray/rllib/tuned_examples/pong-impala-fast.yaml b/python/ray/rllib/tuned_examples/pong-impala-fast.yaml new file mode 100644 index 0000000000000..3c29f4e0c08e4 --- /dev/null +++ b/python/ray/rllib/tuned_examples/pong-impala-fast.yaml @@ -0,0 +1,19 @@ +# This can reach 18-19 reward in ~3 minutes on p3.16xl head w/m4.16xl workers +# 128 workers -> 3 minutes (best case) +# 64 workers -> 4 minutes +# 32 workers -> 7 minutes +# See also: pong-impala.yaml, pong-impala-vectorized.yaml +pong-impala-fast: + env: PongNoFrameskip-v4 + run: IMPALA + config: + sample_batch_size: 50 + train_batch_size: 1000 + num_workers: 128 + num_envs_per_worker: 5 + broadcast_interval: 5 + max_sample_requests_in_flight_per_worker: 1 + num_parallel_data_loaders: 4 + num_gpus: 2 + model: + dim: 42 diff --git a/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml b/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml index 9525f4115521e..b16488b443b80 100644 --- a/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml +++ b/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml @@ -5,7 +5,7 @@ pong-impala-vectorized: env: PongNoFrameskip-v4 run: IMPALA config: - sample_batch_size: 500 # 50 * num_envs_per_worker + sample_batch_size: 50 train_batch_size: 500 num_workers: 32 num_envs_per_worker: 10 diff --git a/python/ray/rllib/tuned_examples/pong-impala.yaml b/python/ray/rllib/tuned_examples/pong-impala.yaml index b54c79849c5ab..527bc905d8e57 100644 --- a/python/ray/rllib/tuned_examples/pong-impala.yaml +++ b/python/ray/rllib/tuned_examples/pong-impala.yaml @@ -2,7 +2,7 @@ # 128 workers -> 8 minutes # 32 workers -> 17 minutes # 16 workers -> 40 min+ -# See also: pong-impala-vectorized.yaml +# See also: pong-impala-fast.yaml, pong-impala-vectorized.yaml pong-impala: env: PongNoFrameskip-v4 run: IMPALA diff --git a/python/ray/rllib/tuned_examples/pong-ppo.yaml b/python/ray/rllib/tuned_examples/pong-ppo.yaml index 1447481643fe5..d7e273cc6e2bd 100644 --- a/python/ray/rllib/tuned_examples/pong-ppo.yaml +++ b/python/ray/rllib/tuned_examples/pong-ppo.yaml @@ -1,17 +1,26 @@ -# On a Tesla K80 GPU, this achieves the maximum reward in about 1-1.5 hours. +# On a single GPU, this achieves maximum reward in ~15-20 minutes. 
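The tuned-example changes above and below (fractional `num_gpus`, smaller `sample_batch_size`, `batch_mode: complete_episodes`, and the new TD3 and fast-IMPALA configs) are all consumed through Tune. As a rough sketch of how one of these files can be run programmatically, assuming the era's `tune.run_experiments` API and using an example file path from this diff:

    import yaml

    import ray
    from ray import tune

    # Load one of the tuned-example files touched in this diff and hand the
    # resulting dict of experiments to Tune, roughly what
    # `python train.py -f <file>` does. The path is an example; resource
    # settings such as a 0.2-GPU share per trial come from the YAML itself.
    with open("python/ray/rllib/tuned_examples/pong-impala-fast.yaml") as f:
        experiments = yaml.safe_load(f)

    ray.init()
    tune.run_experiments(experiments)
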
# -# $ python train.py -f tuned_examples/pong-ppo.yaml --ray-num-gpus=1 +# $ python train.py -f tuned_examples/pong-ppo.yaml # -# - PPO_PongDeterministic-v4_0: TERMINATED [pid=16387], 4984 s, 1117981 ts, 21 rew -# - PPO_PongDeterministic-v4_0: TERMINATED [pid=83606], 4592 s, 1068671 ts, 21 rew -# -pong-deterministic-ppo: - env: PongDeterministic-v4 +pong-ppo: + env: PongNoFrameskip-v4 run: PPO - stop: - episode_reward_mean: 21 config: - gamma: 0.99 - num_workers: 4 - num_sgd_iter: 20 + lambda: 0.95 + kl_coeff: 0.5 + clip_rewards: True + clip_param: 0.1 + vf_clip_param: 10.0 + entropy_coeff: 0.01 + train_batch_size: 5000 + sample_batch_size: 20 + sgd_minibatch_size: 500 + num_sgd_iter: 10 + num_workers: 32 + num_envs_per_worker: 5 + batch_mode: truncate_episodes + observation_filter: NoFilter + vf_share_layers: true num_gpus: 1 + model: + dim: 42 diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml index 425958e5c109f..82ea5846e733c 100644 --- a/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml @@ -6,3 +6,4 @@ cartpole-ppo: time_total_s: 300 config: num_workers: 1 + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml b/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml index 8b9d69fce20a6..63536d3be3704 100644 --- a/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml @@ -15,3 +15,4 @@ pendulum-ppo: num_sgd_iter: 10 model: fcnet_hiddens: [64, 64] + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/swimmer-ars.yaml b/python/ray/rllib/tuned_examples/swimmer-ars.yaml index 338c8a12c2cfc..effb4cfe19a8a 100644 --- a/python/ray/rllib/tuned_examples/swimmer-ars.yaml +++ b/python/ray/rllib/tuned_examples/swimmer-ars.yaml @@ -9,8 +9,9 @@ swimmer-ars: num_workers: 1 sgd_stepsize: 0.02 noise_size: 250000000 - policy_type: LinearPolicy eval_prob: 0.2 offset: 0 observation_filter: NoFilter report_length: 3 + model: + fcnet_hiddens: [] # a linear policy diff --git a/python/ray/rllib/tuned_examples/walker2d-ppo.yaml b/python/ray/rllib/tuned_examples/walker2d-ppo.yaml index deb5a0038dcb7..9d64720a2c5b6 100644 --- a/python/ray/rllib/tuned_examples/walker2d-ppo.yaml +++ b/python/ray/rllib/tuned_examples/walker2d-ppo.yaml @@ -9,3 +9,4 @@ walker2d-v1-ppo: train_batch_size: 320000 num_workers: 64 num_gpus: 4 + batch_mode: complete_episodes diff --git a/python/ray/rllib/utils/actors.py b/python/ray/rllib/utils/actors.py index e865feb431b4b..7018073313112 100644 --- a/python/ray/rllib/utils/actors.py +++ b/python/ray/rllib/utils/actors.py @@ -2,9 +2,12 @@ from __future__ import division from __future__ import print_function +import logging import os import ray +logger = logging.getLogger(__name__) + class TaskPool(object): """Helper class for tracking the status of many in-flight actor tasks.""" @@ -36,11 +39,8 @@ def completed_prefetch(self): for worker, obj_id in self.completed(): plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id()) - if not ray.global_state.use_raylet: - ray.worker.global_worker.plasma_client.fetch([plasma_id]) - else: - (ray.worker.global_worker.local_scheduler_client. - reconstruct_objects([obj_id], True)) + (ray.worker.global_worker.local_scheduler_client. 
+ fetch_or_reconstruct([obj_id], True)) self._fetching.append((worker, obj_id)) remaining = [] @@ -80,11 +80,12 @@ def split_colocated(actors): def try_create_colocated(cls, args, count): actors = [cls.remote(*args) for _ in range(count)] local, _ = split_colocated(actors) - print("Got {} colocated actors of {}".format(len(local), count)) + logger.info("Got {} colocated actors of {}".format(len(local), count)) return local def create_colocated(cls, args, count): + logger.info("Trying to create {} colocated actors".format(count)) ok = [] i = 1 while len(ok) < count and i < 10: diff --git a/python/ray/rllib/utils/annotations.py b/python/ray/rllib/utils/annotations.py new file mode 100644 index 0000000000000..d68f76a69600e --- /dev/null +++ b/python/ray/rllib/utils/annotations.py @@ -0,0 +1,20 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def override(cls): + """Annotation for documenting method overrides. + + Arguments: + cls (type): The superclass that provides the overriden method. If this + cls does not actually have the method, an error is raised. + """ + + def check_override(method): + if method.__name__ not in dir(cls): + raise NameError("{} does not override any method of {}".format( + method, cls)) + return method + + return check_override diff --git a/python/ray/rllib/utils/compression.py b/python/ray/rllib/utils/compression.py index 5f28455ee44aa..aed0dd5985600 100644 --- a/python/ray/rllib/utils/compression.py +++ b/python/ray/rllib/utils/compression.py @@ -2,18 +2,21 @@ from __future__ import division from __future__ import print_function +import logging import time import base64 import numpy as np import pyarrow +logger = logging.getLogger(__name__) + try: import lz4.frame LZ4_ENABLED = True except ImportError: - print("WARNING: lz4 not available, disabling sample compression. " - "This will significantly impact RLlib performance. " - "To install lz4, run `pip install lz4`.") + logger.warn("lz4 not available, disabling sample compression. " + "This will significantly impact RLlib performance. " + "To install lz4, run `pip install lz4`.") LZ4_ENABLED = False diff --git a/python/ray/rllib/utils/filter.py b/python/ray/rllib/utils/filter.py index b2a3619481cdd..9a1f37dbd15a5 100644 --- a/python/ray/rllib/utils/filter.py +++ b/python/ray/rllib/utils/filter.py @@ -2,9 +2,12 @@ from __future__ import division from __future__ import print_function +import logging import numpy as np import threading +logger = logging.getLogger(__name__) + class Filter(object): """Processes input, possibly statefully.""" @@ -39,7 +42,10 @@ def __init__(self, *args): pass def __call__(self, x, update=True): - return np.asarray(x) + try: + return np.asarray(x) + except Exception: + raise ValueError("Failed to convert to array", x) def apply_changes(self, other, *args, **kwargs): pass @@ -74,8 +80,10 @@ def copy(self): def push(self, x): x = np.asarray(x) # Unvectorized update of the running statistics. 
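The new ray/rllib/utils/annotations.py above introduces an `override` decorator that documents an override and fails fast if the named superclass has no such method. A minimal usage sketch; the `Policy`/`RandomPolicy` classes are invented for illustration:

    from ray.rllib.utils.annotations import override


    class Policy(object):
        def compute_action(self, obs):
            raise NotImplementedError


    class RandomPolicy(Policy):
        @override(Policy)
        def compute_action(self, obs):
            # Documented override; if Policy had no compute_action attribute,
            # @override(Policy) would raise a NameError at class definition.
            return 0
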
- assert x.shape == self._M.shape, ("x.shape = {}, self.shape = {}" - .format(x.shape, self._M.shape)) + if x.shape != self._M.shape: + raise ValueError( + "Unexpected input shape {}, expected {}, value = {}".format( + x.shape, self._M.shape, x)) n1 = self._n self._n += 1 if self._n == 1: diff --git a/python/ray/rllib/utils/policy_client.py b/python/ray/rllib/utils/policy_client.py index 901dc983b0985..1bb4b5e134046 100644 --- a/python/ray/rllib/utils/policy_client.py +++ b/python/ray/rllib/utils/policy_client.py @@ -2,14 +2,17 @@ from __future__ import division from __future__ import print_function +import logging import pickle +logger = logging.getLogger(__name__) + try: import requests # `requests` is not part of stdlib. except ImportError: requests = None - print("Couldn't import `requests` library. Be sure to install it on" - " the client side.") + logger.warn("Couldn't import `requests` library. Be sure to install it on" + " the client side.") class PolicyClient(object): @@ -109,8 +112,7 @@ def _send(self, data): payload = pickle.dumps(data) response = requests.post(self._address, data=payload) if response.status_code != 200: - print("Request failed", data) - print(response.text) + logger.error("Request failed {}: {}".format(response.text, data)) response.raise_for_status() parsed = pickle.loads(response.content) return parsed diff --git a/python/ray/rllib/utils/policy_server.py b/python/ray/rllib/utils/policy_server.py index 13ca376bb82ab..25238971fd139 100644 --- a/python/ray/rllib/utils/policy_server.py +++ b/python/ray/rllib/utils/policy_server.py @@ -18,15 +18,15 @@ class PolicyServer(ThreadingMixIn, HTTPServer): - """REST server than can be launched from a ServingEnv. + """REST server than can be launched from a ExternalEnv. This launches a multi-threaded server that listens on the specified host and port to serve policy requests and forward experiences to RLlib. 
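The PolicyClient changes above switch failure reporting to the logger, and the PolicyServer it talks to is renamed below to use ExternalEnv terminology. A rough client-side loop implied by the command handlers (START_EPISODE, GET_ACTION, LOG_RETURNS, END_EPISODE); the address and the exact method signatures are assumptions, since only the server-side dispatch is shown in this diff:

    from ray.rllib.utils.policy_client import PolicyClient

    # Assumes a PolicyServer is already listening at this placeholder address.
    client = PolicyClient("http://localhost:9900")

    obs = [0.0, 0.0, 0.0, 0.0]  # placeholder observation
    episode_id = client.start_episode(training_enabled=True)
    action = client.get_action(episode_id, obs)
    client.log_returns(episode_id, 1.0)
    client.end_episode(episode_id, obs)
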
Examples: - >>> class CartpoleServing(ServingEnv): + >>> class CartpoleServing(ExternalEnv): def __init__(self): - ServingEnv.__init__( + ExternalEnv.__init__( self, spaces.Discrete(2), spaces.Box( low=-10, @@ -50,12 +50,12 @@ def run(self): >>> client.log_returns(eps_id, reward) """ - def __init__(self, serving_env, address, port): - handler = _make_handler(serving_env) + def __init__(self, external_env, address, port): + handler = _make_handler(external_env) HTTPServer.__init__(self, (address, port), handler) -def _make_handler(serving_env): +def _make_handler(external_env): class Handler(SimpleHTTPRequestHandler): def do_POST(self): content_len = int(self.headers.get('Content-Length'), 0) @@ -73,20 +73,20 @@ def execute_command(self, args): command = args["command"] response = {} if command == PolicyClient.START_EPISODE: - response["episode_id"] = serving_env.start_episode( + response["episode_id"] = external_env.start_episode( args["episode_id"], args["training_enabled"]) elif command == PolicyClient.GET_ACTION: - response["action"] = serving_env.get_action( + response["action"] = external_env.get_action( args["episode_id"], args["observation"]) elif command == PolicyClient.LOG_ACTION: - serving_env.log_action(args["episode_id"], args["observation"], - args["action"]) + external_env.log_action(args["episode_id"], + args["observation"], args["action"]) elif command == PolicyClient.LOG_RETURNS: - serving_env.log_returns(args["episode_id"], args["reward"], - args["info"]) + external_env.log_returns(args["episode_id"], args["reward"], + args["info"]) elif command == PolicyClient.END_EPISODE: - serving_env.end_episode(args["episode_id"], - args["observation"]) + external_env.end_episode(args["episode_id"], + args["observation"]) else: raise Exception("Unknown command: {}".format(command)) return response diff --git a/python/ray/rllib/utils/reshaper.py b/python/ray/rllib/utils/reshaper.py deleted file mode 100644 index e9c16521210c4..0000000000000 --- a/python/ray/rllib/utils/reshaper.py +++ /dev/null @@ -1,49 +0,0 @@ -import numpy as np -import tensorflow as tf - - -class Reshaper(object): - """ - This class keeps track of where in the flattened observation space - we should be slicing and what the new shapes should be - """ - - def __init__(self, env_space): - self.shapes = [] - self.slice_positions = [] - self.env_space = env_space - if isinstance(env_space, list): - for space in env_space: - # Handle both gym arrays and just lists of inputs length - if hasattr(space, "n"): - arr_shape = np.asarray([1]) # discrete space - elif hasattr(space, "shape"): - arr_shape = np.asarray(space.shape) - else: - arr_shape = space - self.shapes.append(arr_shape) - if len(self.slice_positions) == 0: - self.slice_positions.append(np.product(arr_shape)) - else: - self.slice_positions.append( - np.product(arr_shape) + self.slice_positions[-1]) - else: - self.shapes.append(np.asarray(env_space.shape)) - self.slice_positions.append(np.product(env_space.shape)) - - def get_slice_lengths(self): - diffed_list = np.diff(self.slice_positions).tolist() - diffed_list.insert(0, self.slice_positions[0]) - return np.asarray(diffed_list).astype(int) - - def split_tensor(self, tensor, axis=-1): - # FIXME (ev) This won't work for mixed action distributions like - # one agent Gaussian one agent discrete - slice_rescale = int(tensor.shape.as_list()[axis] / int( - np.sum(self.get_slice_lengths()))) - return tf.split( - tensor, slice_rescale * self.get_slice_lengths(), axis=axis) - - def split_number(self, number): - 
slice_rescale = int(number / int(np.sum(self.get_slice_lengths()))) - return slice_rescale * self.get_slice_lengths() diff --git a/python/ray/rllib/utils/tf_run_builder.py b/python/ray/rllib/utils/tf_run_builder.py index 030642ae5b6ae..4359c1b5e546f 100644 --- a/python/ray/rllib/utils/tf_run_builder.py +++ b/python/ray/rllib/utils/tf_run_builder.py @@ -2,12 +2,15 @@ from __future__ import division from __future__ import print_function +import logging import os import time import tensorflow as tf from tensorflow.python.client import timeline +logger = logging.getLogger(__name__) + class TFRunBuilder(object): """Used to incrementally build up a TensorFlow run. @@ -26,7 +29,8 @@ def __init__(self, session, debug_name): def add_feed_dict(self, feed_dict): assert not self._executed for k in feed_dict: - assert k not in self.feed_dict + if k in self.feed_dict: + raise ValueError("Key added twice: {}".format(k)) self.feed_dict.update(feed_dict) def add_fetches(self, fetches): @@ -41,10 +45,9 @@ def get(self, to_fetch): self._executed = run_timeline( self.session, self.fetches, self.debug_name, self.feed_dict, os.environ.get("TF_TIMELINE_DIR")) - except Exception as e: - print("Error fetching: {}, feed_dict={}".format( + except Exception: + raise ValueError("Error fetching: {}, feed_dict={}".format( self.fetches, self.feed_dict)) - raise e if isinstance(to_fetch, int): return self._executed[to_fetch] elif isinstance(to_fetch, list): @@ -75,8 +78,8 @@ def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None): debug_name, os.getpid(), _count)) _count += 1 trace_file = open(outf, "w") - print("Wrote tf timeline ({} s) to {}".format(time.time() - start, - os.path.abspath(outf))) + logger.info("Wrote tf timeline ({} s) to {}".format( + time.time() - start, os.path.abspath(outf))) trace_file.write(trace.generate_chrome_trace_format()) else: fetches = sess.run(ops, feed_dict=feed_dict) diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 0826a1387aec8..b84db6757c86a 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -20,7 +20,7 @@ def check_no_existing_redis_clients(node_ip_address, redis_client): # The client table prefix must be kept in sync with the file - # "src/common/redis_module/ray_redis_module.cc" where it is defined. + # "src/ray/gcs/redis_module/ray_redis_module.cc" where it is defined. REDIS_CLIENT_TABLE_PREFIX = "CL:" client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX)) # Filter to clients on the same node and do some basic checking. @@ -89,6 +89,11 @@ def cli(logging_level, logging_format): type=int, help=("If provided, attempt to configure Redis with this " "maximum number of clients.")) +@click.option( + "--redis-password", + required=False, + type=str, + help="If provided, secure Redis ports with this password") @click.option( "--redis-shard-ports", required=False, @@ -100,12 +105,33 @@ def cli(logging_level, logging_format): required=False, type=int, help="the port to use for starting the object manager") +@click.option( + "--node-manager-port", + required=False, + type=int, + help="the port to use for starting the node manager") @click.option( "--object-store-memory", required=False, type=int, help="the maximum amount of memory (in bytes) to allow the " "object store to use") +@click.option( + "--redis-max-memory", + required=False, + type=int, + help=("The max amount of memory (in bytes) to allow redis to use, or None " + "for no limit. 
Once the limit is exceeded, redis will start LRU " + "eviction of entries. This only applies to the sharded " + "redis tables (task and object tables).")) +@click.option( + "--collect-profiling-data", + default=True, + type=bool, + help=("Whether to collect profiling data. Note that " + "profiling data cannot be LRU evicted, so if you set " + "redis_max_memory then profiling will also be disabled to prevent " + "it from consuming all available redis memory.")) @click.option( "--num-workers", required=False, @@ -162,11 +188,6 @@ def cli(logging_level, logging_format): required=False, type=str, help="the file that contains the autoscaling config") -@click.option( - "--use-raylet", - is_flag=True, - default=None, - help="use the raylet code path") @click.option( "--no-redirect-worker-output", is_flag=True, @@ -177,25 +198,40 @@ def cli(logging_level, logging_format): is_flag=True, default=False, help="do not redirect non-worker stdout and stderr to files") +@click.option( + "--plasma-store-socket-name", + default=None, + help="manually specify the socket name of the plasma store") +@click.option( + "--raylet-socket-name", + default=None, + help="manually specify the socket path of the raylet process") +@click.option( + "--temp-dir", + default=None, + help="manually specify the root temporary dir of the Ray process") +@click.option( + "--internal-config", + default=None, + type=str, + help="Do NOT use this. This is for debugging/development purposes ONLY.") def start(node_ip_address, redis_address, redis_port, num_redis_shards, - redis_max_clients, redis_shard_ports, object_manager_port, - object_store_memory, num_workers, num_cpus, num_gpus, resources, - head, no_ui, block, plasma_directory, huge_pages, autoscaling_config, - use_raylet, no_redirect_worker_output, no_redirect_output): + redis_max_clients, redis_password, redis_shard_ports, + object_manager_port, node_manager_port, object_store_memory, + redis_max_memory, collect_profiling_data, num_workers, num_cpus, + num_gpus, resources, head, no_ui, block, plasma_directory, + huge_pages, autoscaling_config, no_redirect_worker_output, + no_redirect_output, plasma_store_socket_name, raylet_socket_name, + temp_dir, internal_config): # Convert hostnames to numerical IP address. if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) if redis_address is not None: redis_address = services.address_to_ip(redis_address) - if use_raylet is None and os.environ.get("RAY_USE_XRAY") == "1": - # This environment variable is used in our testing setup. - logger.info("Detected environment variable 'RAY_USE_XRAY'.") - use_raylet = True - try: resources = json.loads(resources) - except Exception as e: + except Exception: raise Exception("Unable to parse the --resources argument using " "json.loads. Try using a format like\n\n" " --resources='{\"CustomResource1\": 3, " @@ -235,19 +271,15 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, logger.info("Using IP address {} for this node." .format(node_ip_address)) - address_info = {} - # Use the provided object manager port if there is one. 
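A little earlier in this diff, TFRunBuilder stopped printing fetch errors (it now raises a ValueError) and reports duplicate feed keys explicitly. A sketch of how the builder is typically driven, assuming `add_fetches` returns indices that can later be passed to `get()`, which the `get()` code suggests but this diff does not show:

    import tensorflow as tf

    from ray.rllib.utils.tf_run_builder import TFRunBuilder

    # Toy TF1-style graph; the tensors are illustrative only.
    x = tf.placeholder(tf.float32, shape=())
    y = x * 2.0

    with tf.Session() as sess:
        builder = TFRunBuilder(sess, "toy_example")
        builder.add_feed_dict({x: 3.0})
        # Feeding the same placeholder twice would now raise
        # ValueError("Key added twice: ...") instead of failing a bare assert.
        fetch_ids = builder.add_fetches([y])
        print(builder.get(fetch_ids))  # the session only runs here, lazily
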
- if object_manager_port is not None: - address_info["object_manager_ports"] = [object_manager_port] - if address_info == {}: - address_info = None - address_info = services.start_ray_head( - address_info=address_info, + object_manager_ports=[object_manager_port], + node_manager_ports=[node_manager_port], node_ip_address=node_ip_address, redis_port=redis_port, redis_shard_ports=redis_shard_ports, object_store_memory=object_store_memory, + redis_max_memory=redis_max_memory, + collect_profiling_data=collect_profiling_data, num_workers=num_workers, cleanup=False, redirect_worker_output=not no_redirect_worker_output, @@ -255,26 +287,33 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, - redis_protected_mode=False, + redis_password=redis_password, include_webui=(not no_ui), plasma_directory=plasma_directory, huge_pages=huge_pages, autoscaling_config=autoscaling_config, - use_raylet=use_raylet) + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir, + _internal_config=internal_config) logger.info(address_info) logger.info( "\nStarted Ray on this node. You can add additional nodes to " "the cluster by calling\n\n" - " ray start --redis-address {}\n\n" + " ray start --redis-address {}{}{}\n\n" "from the node you wish to add. You can connect a driver to the " "cluster from Python by running\n\n" " import ray\n" - " ray.init(redis_address=\"{}\")\n\n" + " ray.init(redis_address=\"{}{}{}\")\n\n" "If you have trouble connecting from a different machine, check " "that your firewall is configured properly. If you wish to " "terminate the processes that have been started, run\n\n" - " ray stop".format(address_info["redis_address"], - address_info["redis_address"])) + " ray stop".format( + address_info["redis_address"], " --redis-password " + if redis_password else "", redis_password if redis_password + else "", address_info["redis_address"], "\", redis_password=\"" + if redis_password else "", redis_password + if redis_password else "")) else: # Start Ray on a non-head node. if redis_port is not None: @@ -299,10 +338,12 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, # Wait for the Redis server to be started. And throw an exception if we # can't connect to it. - services.wait_for_redis_to_start(redis_ip_address, int(redis_port)) + services.wait_for_redis_to_start( + redis_ip_address, int(redis_port), password=redis_password) # Create a Redis client. - redis_client = services.create_redis_client(redis_address) + redis_client = services.create_redis_client( + redis_address, password=redis_password) # Check that the verion information on this node matches the version # information that the cluster was started with. 
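With the new `--redis-password` option, the head-node startup message above tells users to pass the password to `ray.init` as well. Based on that hint, a driver connects roughly like this (address and password are placeholders):

    import ray

    # Matches the hint printed by `ray start --head --redis-password <pw>`;
    # both values below stand in for the ones printed on your head node.
    ray.init(
        redis_address="192.168.0.10:6379",
        redis_password="my-cluster-password")
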
@@ -321,15 +362,20 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, node_ip_address=node_ip_address, redis_address=redis_address, object_manager_ports=[object_manager_port], + node_manager_ports=[node_manager_port], num_workers=num_workers, object_store_memory=object_store_memory, + redis_password=redis_password, cleanup=False, redirect_worker_output=not no_redirect_worker_output, redirect_output=not no_redirect_output, resources=resources, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet) + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir, + _internal_config=internal_config) logger.info(address_info) logger.info("\nStarted Ray on this node. If you wish to terminate the " "processes that have been started, run\n\n" @@ -344,11 +390,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, @cli.command() def stop(): subprocess.call( - [ - "killall global_scheduler plasma_store_server plasma_manager " - "local_scheduler raylet raylet_monitor" - ], - shell=True) + ["killall plasma_store_server raylet raylet_monitor"], shell=True) # Find the PID of the monitor process and kill it. subprocess.call( @@ -373,6 +415,12 @@ def stop(): "grep -v grep | awk '{ print $2 }') 2> /dev/null" ], shell=True) + subprocess.call( + [ + "kill -9 $(ps aux | grep ' ray_' | " + "grep -v grep | awk '{ print $2 }') 2> /dev/null" + ], + shell=True) # Find the PID of the Ray log monitor process and kill it. subprocess.call( @@ -387,10 +435,10 @@ def stop(): from notebook.notebookapp import list_running_servers pids = [ str(server["pid"]) for server in list_running_servers() - if "/tmp/raylogs" in server["notebook_dir"] + if "/tmp/ray" in server["notebook_dir"] ] subprocess.call( - ["kill {} 2> /dev/null".format(" ".join(pids))], shell=True) + ["kill -9 {} 2> /dev/null".format(" ".join(pids))], shell=True) except ImportError: pass @@ -413,24 +461,24 @@ def stop(): "--min-workers", required=False, type=int, - help=("Override the configured min worker node count for the cluster.")) + help="Override the configured min worker node count for the cluster.") @click.option( "--max-workers", required=False, type=int, - help=("Override the configured max worker node count for the cluster.")) + help="Override the configured max worker node count for the cluster.") @click.option( "--cluster-name", "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") @click.option( "--yes", "-y", is_flag=True, default=False, - help=("Don't ask for confirmation.")) + help="Don't ask for confirmation.") def create_or_update(cluster_config_file, min_workers, max_workers, no_restart, restart_only, yes, cluster_name): if restart_only or no_restart: @@ -446,19 +494,19 @@ def create_or_update(cluster_config_file, min_workers, max_workers, no_restart, "--workers-only", is_flag=True, default=False, - help=("Only destroy the workers.")) + help="Only destroy the workers.") @click.option( "--yes", "-y", is_flag=True, default=False, - help=("Don't ask for confirmation.")) + help="Don't ask for confirmation.") @click.option( "--cluster-name", "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") def teardown(cluster_config_file, yes, workers_only, cluster_name): teardown_cluster(cluster_config_file, yes, workers_only, cluster_name) @@ -469,17 +517,17 @@ def 
teardown(cluster_config_file, yes, workers_only, cluster_name): "--start", is_flag=True, default=False, - help=("Start the cluster if needed.")) + help="Start the cluster if needed.") @click.option( - "--tmux", is_flag=True, default=False, help=("Run the command in tmux.")) + "--tmux", is_flag=True, default=False, help="Run the command in tmux.") @click.option( "--cluster-name", "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") @click.option( - "--new", "-N", is_flag=True, help=("Force creation of a new screen.")) + "--new", "-N", is_flag=True, help="Force creation of a new screen.") def attach(cluster_config_file, start, tmux, cluster_name, new): attach_cluster(cluster_config_file, start, tmux, cluster_name, new) @@ -493,7 +541,7 @@ def attach(cluster_config_file, start, tmux, cluster_name, new): "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") def rsync_down(cluster_config_file, source, target, cluster_name): rsync(cluster_config_file, source, target, cluster_name, down=True) @@ -507,11 +555,77 @@ def rsync_down(cluster_config_file, source, target, cluster_name): "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") def rsync_up(cluster_config_file, source, target, cluster_name): rsync(cluster_config_file, source, target, cluster_name, down=False) +@cli.command() +@click.argument("cluster_config_file", required=True, type=str) +@click.option( + "--stop", + is_flag=True, + default=False, + help="Stop the cluster after the command finishes running.") +@click.option( + "--start", + is_flag=True, + default=False, + help="Start the cluster if needed.") +@click.option( + "--screen", + is_flag=True, + default=False, + help="Run the command in a screen.") +@click.option( + "--tmux", is_flag=True, default=False, help="Run the command in tmux.") +@click.option( + "--cluster-name", + "-n", + required=False, + type=str, + help="Override the configured cluster name.") +@click.option( + "--port-forward", required=False, type=int, help="Port to forward.") +@click.argument("script", required=True, type=str) +@click.argument("script_args", required=False, type=str, nargs=-1) +def submit(cluster_config_file, screen, tmux, stop, start, cluster_name, + port_forward, script, script_args): + """Uploads and runs a script on the specified cluster. + + The script is automatically synced to the following location: + + os.path.join("~", os.path.basename(script)) + """ + assert not (screen and tmux), "Can specify only one of `screen` or `tmux`." 
+ + if start: + create_or_update_cluster(cluster_config_file, None, None, False, False, + True, cluster_name) + + target = os.path.join("~", os.path.basename(script)) + rsync(cluster_config_file, script, target, cluster_name, down=False) + + cmd = " ".join(["python", target] + list(script_args)) + exec_cluster(cluster_config_file, cmd, screen, tmux, stop, False, + cluster_name, port_forward) + + if tmux or screen: + attach_command_parts = ["ray attach", cluster_config_file] + if cluster_name is not None: + attach_command_parts.append( + "--cluster-name={}".format(cluster_name)) + if tmux: + attach_command_parts.append("--tmux") + elif screen: + attach_command_parts.append("--screen") + + attach_command = " ".join(attach_command_parts) + attach_info = "Use `{}` to check on command status.".format( + attach_command) + logger.info(attach_info) + + @cli.command() @click.argument("cluster_config_file", required=True, type=str) @click.argument("cmd", required=True, type=str) @@ -519,35 +633,48 @@ def rsync_up(cluster_config_file, source, target, cluster_name): "--stop", is_flag=True, default=False, - help=("Stop the cluster after the command finishes running.")) + help="Stop the cluster after the command finishes running.") @click.option( "--start", is_flag=True, default=False, - help=("Start the cluster if needed.")) + help="Start the cluster if needed.") @click.option( "--screen", is_flag=True, default=False, - help=("Run the command in a screen.")) + help="Run the command in a screen.") @click.option( - "--tmux", is_flag=True, default=False, help=("Run the command in tmux.")) + "--tmux", is_flag=True, default=False, help="Run the command in tmux.") @click.option( "--cluster-name", "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") @click.option( - "--port-forward", required=False, type=int, help=("Port to forward.")) + "--port-forward", required=False, type=int, help="Port to forward.") def exec_cmd(cluster_config_file, cmd, screen, tmux, stop, start, cluster_name, port_forward): assert not (screen and tmux), "Can specify only one of `screen` or `tmux`." + exec_cluster(cluster_config_file, cmd, screen, tmux, stop, start, cluster_name, port_forward) - if tmux: - logger.info("Use `ray attach {} --tmux` " - "to check on command status.".format(cluster_config_file)) + + if tmux or screen: + attach_command_parts = ["ray attach", cluster_config_file] + if cluster_name is not None: + attach_command_parts.append( + "--cluster-name={}".format(cluster_name)) + if tmux: + attach_command_parts.append("--tmux") + elif screen: + attach_command_parts.append("--screen") + + attach_command = " ".join(attach_command_parts) + attach_info = "Use `{}` to check on command status.".format( + attach_command) + logger.info(attach_info) @cli.command() @@ -557,22 +684,49 @@ def exec_cmd(cluster_config_file, cmd, screen, tmux, stop, start, cluster_name, "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") def get_head_ip(cluster_config_file, cluster_name): click.echo(get_head_node_ip(cluster_config_file, cluster_name)) +@cli.command() +def stack(): + COMMAND = """ +pyspy=`which py-spy` +if [ ! -e "$pyspy" ]; then + echo "ERROR: Please 'pip install py-spy' (or ray[debug]) first" + exit 1 +fi +# Set IFS to iterate over lines instead of over words. +export IFS=" +" +# Call sudo to prompt for password before anything has been printed. 
+sudo true +workers=$( + ps aux | grep ' ray_' | grep -v grep +) +for worker in $workers; do + echo "Stack dump for $worker"; + pid=`echo $worker | awk '{print $2}'`; + sudo $pyspy --pid $pid --dump; + echo; +done + """ + subprocess.call(COMMAND, shell=True) + + cli.add_command(start) cli.add_command(stop) -cli.add_command(create_or_update) cli.add_command(create_or_update, name="up") cli.add_command(attach) cli.add_command(exec_cmd, name="exec") -cli.add_command(rsync_down) -cli.add_command(rsync_up) +cli.add_command(rsync_down, name="rsync_down") +cli.add_command(rsync_up, name="rsync_up") +cli.add_command(submit) cli.add_command(teardown) cli.add_command(teardown, name="down") -cli.add_command(get_head_ip) +cli.add_command(get_head_ip, name="get_head_ip") +cli.add_command(stack) def main(): diff --git a/python/ray/services.py b/python/ray/services.py index 3a421437c5664..e96196b5f9463 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -2,43 +2,43 @@ from __future__ import division from __future__ import print_function -import binascii import json import logging import multiprocessing import os import random import resource -import shutil import signal import socket import subprocess import sys import threading import time -from collections import OrderedDict, namedtuple -from datetime import datetime - +from collections import OrderedDict import redis import pyarrow # Ray modules import ray.ray_constants -import ray.global_scheduler as global_scheduler -import ray.local_scheduler import ray.plasma +from ray.tempfile_services import ( + get_ipython_notebook_path, get_logs_dir_path, get_raylet_socket_name, + get_temp_root, new_log_monitor_log_file, new_monitor_log_file, + new_plasma_store_log_file, new_raylet_log_file, new_redis_log_file, + new_webui_log_file, set_temp_root) + PROCESS_TYPE_MONITOR = "monitor" PROCESS_TYPE_LOG_MONITOR = "log_monitor" PROCESS_TYPE_WORKER = "worker" PROCESS_TYPE_RAYLET = "raylet" -PROCESS_TYPE_LOCAL_SCHEDULER = "local_scheduler" -PROCESS_TYPE_PLASMA_MANAGER = "plasma_manager" PROCESS_TYPE_PLASMA_STORE = "plasma_store" -PROCESS_TYPE_GLOBAL_SCHEDULER = "global_scheduler" PROCESS_TYPE_REDIS_SERVER = "redis_server" PROCESS_TYPE_WEB_UI = "web_ui" +# Max bytes to allocate to plasma unless overriden by the user +MAX_DEFAULT_MEM = 20 * 1000 * 1000 * 1000 + # This is a dictionary tracking all of the processes of different types that # have been started by this services module. Note that the order of the keys is # important because it determines the order in which these processes will be @@ -47,23 +47,20 @@ all_processes = OrderedDict( [(PROCESS_TYPE_MONITOR, []), (PROCESS_TYPE_LOG_MONITOR, []), (PROCESS_TYPE_WORKER, []), (PROCESS_TYPE_RAYLET, []), - (PROCESS_TYPE_LOCAL_SCHEDULER, []), (PROCESS_TYPE_PLASMA_MANAGER, []), - (PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_GLOBAL_SCHEDULER, []), - (PROCESS_TYPE_REDIS_SERVER, []), (PROCESS_TYPE_WEB_UI, [])], ) + (PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_REDIS_SERVER, []), + (PROCESS_TYPE_WEB_UI, [])], ) # True if processes are run in the valgrind profiler. RUN_RAYLET_PROFILER = False -RUN_LOCAL_SCHEDULER_PROFILER = False -RUN_PLASMA_MANAGER_PROFILER = False RUN_PLASMA_STORE_PROFILER = False # Location of the redis server and module. 
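The `ray submit` command registered above is a thin wrapper: it rsyncs the script to the head node and then runs `python <script>` there via `exec_cluster`. A rough Python equivalent of that flow, assuming the helpers it calls live in `ray.autoscaler.commands` as elsewhere in scripts.py; the file names are placeholders:

    import os

    from ray.autoscaler.commands import exec_cluster, rsync

    cluster_config_file = "cluster.yaml"   # placeholder
    script = "my_script.py"                # placeholder

    # Mirror submit(): upload the script, then run it on the head node.
    target = os.path.join("~", os.path.basename(script))
    rsync(cluster_config_file, script, target, None, down=False)
    exec_cluster(cluster_config_file, "python " + target,
                 False, False, False, False, None, None)
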
REDIS_EXECUTABLE = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "core/src/common/thirdparty/redis/src/redis-server") + "core/src/ray/thirdparty/redis/src/redis-server") REDIS_MODULE = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "core/src/common/redis_module/libray_redis_module.so") + "core/src/ray/gcs/redis_module/libray_redis_module.so") # Location of the credis server and modules. # credis will be enabled if the environment variable RAY_USE_NEW_GCS is set. @@ -84,14 +81,6 @@ RAYLET_EXECUTABLE = os.path.join( os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet") -# ObjectStoreAddress tuples contain all information necessary to connect to an -# object store. The fields are: -# - name: The socket name for the object store -# - manager_name: The socket name for the object store manager -# - manager_port: The Internet port that the object store manager listens on -ObjectStoreAddress = namedtuple("ObjectStoreAddress", - ["name", "manager_name", "manager_port"]) - # Logger for this module. It should be configured at the entry point # into the program using Ray. Ray configures it by default automatically # using logging.basicConfig in its entry/init points. @@ -120,10 +109,6 @@ def new_port(): return random.randint(10000, 65535) -def random_name(): - return str(random.randint(0, 99999999)) - - def kill_process(p): """Kill a process. @@ -136,10 +121,7 @@ def kill_process(p): if p.poll() is not None: # The process has already terminated. return True - if any([ - RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER, - RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER - ]): + if any([RUN_RAYLET_PROFILER, RUN_PLASMA_STORE_PROFILER]): # Give process signal to write profiler data. os.kill(p.pid, signal.SIGINT) # Wait for profiling data to be written. @@ -187,18 +169,21 @@ def cleanup(): logger.warning("Ray did not shut down properly.") -def all_processes_alive(exclude=[]): +def all_processes_alive(exclude=None): """Check if all of the processes are still alive. Args: exclude: Don't check the processes whose types are in this list. """ + + if exclude is None: + exclude = [] for process_type, processes in all_processes.items(): # Note that p.poll() returns the exit code that the process exited # with, so an exit code of None indicates that the process is still # alive. processes_alive = [p.poll() is None for p in processes] - if (not all(processes_alive) and process_type not in exclude): + if not all(processes_alive) and process_type not in exclude: logger.warning( "A process of type {} has died.".format(process_type)) return False @@ -258,7 +243,10 @@ def get_node_ip_address(address="8.8.8.8:53"): return node_ip_address -def record_log_files_in_redis(redis_address, node_ip_address, log_files): +def record_log_files_in_redis(redis_address, + node_ip_address, + log_files, + password=None): """Record in Redis that a new log file has been created. This is used so that each log monitor can check Redis and figure out which @@ -270,23 +258,24 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files): on. log_files: A list of file handles for the log files. If one of the file handles is None, we ignore it. + password (str): The password of the redis server. 
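The `all_processes_alive()` change above swaps a mutable default argument (`exclude=[]`) for `None`. The usual motivation is that a list default is created once at definition time and shared across calls, so any mutation leaks into later calls; a self-contained illustration with made-up function names:

    def buggy(exclude=[]):
        exclude.append("x")      # mutates the single shared default list
        return list(exclude)

    def fixed(exclude=None):
        if exclude is None:
            exclude = []         # fresh list on every call
        exclude.append("x")
        return exclude

    assert buggy() == ["x"]
    assert buggy() == ["x", "x"]  # state leaked in from the first call
    assert fixed() == ["x"]
    assert fixed() == ["x"]
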
""" for log_file in log_files: if log_file is not None: redis_ip_address, redis_port = redis_address.split(":") redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=password) # The name of the key storing the list of log filenames for this IP # address. log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address) redis_client.rpush(log_file_list_key, log_file.name) -def create_redis_client(redis_address): +def create_redis_client(redis_address, password=None): """Create a Redis client. Args: - The IP address and port of the Redis server. + The IP address, port, and password of the Redis server. Returns: A Redis client. @@ -294,10 +283,14 @@ def create_redis_client(redis_address): redis_ip_address, redis_port = redis_address.split(":") # For this command to work, some other client (on the same machine # as Redis) must have run "CONFIG SET protected-mode no". - return redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) + return redis.StrictRedis( + host=redis_ip_address, port=int(redis_port), password=password) -def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): +def wait_for_redis_to_start(redis_ip_address, + redis_port, + password=None, + num_retries=5): """Wait for a Redis server to be available. This is accomplished by creating a Redis client and sending a random @@ -306,13 +299,15 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): Args: redis_ip_address (str): The IP address of the redis server. redis_port (int): The port of the redis server. + password (str): The password of the redis server. num_retries (int): The number of times to try connecting with redis. The client will sleep for one second between attempts. Raises: Exception: An exception is raised if we could not connect with Redis. """ - redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port) + redis_client = redis.StrictRedis( + host=redis_ip_address, port=redis_port, password=password) # Wait for the Redis server to start. counter = 0 while counter < num_retries: @@ -322,7 +317,7 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): "Waiting for redis server at {}:{} to respond...".format( redis_ip_address, redis_port)) redis_client.client_list() - except redis.ConnectionError as e: + except redis.ConnectionError: # Wait a little bit. time.sleep(1) logger.info("Failed to connect to the redis server, retrying.") @@ -358,7 +353,7 @@ def _compute_version_info(): ray_version = ray.__version__ python_version = ".".join(map(str, sys.version_info[:3])) pyarrow_version = pyarrow.__version__ - return (ray_version, python_version, pyarrow_version) + return ray_version, python_version, pyarrow_version def _put_version_info_in_redis(redis_client): @@ -417,12 +412,12 @@ def start_redis(node_ip_address, redis_shard_ports=None, num_redis_shards=1, redis_max_clients=None, - use_raylet=False, redirect_output=False, redirect_worker_output=False, cleanup=True, - protected_mode=False, - use_credis=None): + password=None, + use_credis=None, + redis_max_memory=None): """Start the Redis global state store. Args: @@ -437,8 +432,6 @@ def start_redis(node_ip_address, shard. redis_max_clients: If this is provided, Ray will attempt to configure Redis with this maxclients number. - use_raylet: True if the new raylet code path should be used. This is - not supported yet. redirect_output (bool): True if output should be redirected to a file and false otherwise. 
redirect_worker_output (bool): True if worker output should be @@ -448,16 +441,21 @@ def start_redis(node_ip_address, then all Redis processes started by this method will be killed by services.cleanup() when the Python process that imported services exits. + password (str): Prevents external clients without the password + from connecting to Redis if provided. use_credis: If True, additionally load the chain-replicated libraries into the redis servers. Defaults to None, which means its value is set by the presence of "RAY_USE_NEW_GCS" in os.environ. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. This only applies to the + sharded redis tables (task and object tables). Returns: A tuple of the address for the primary Redis shard and a list of addresses for the remaining shards. """ - redis_stdout_file, redis_stderr_file = new_log_files( - "redis", redirect_output) + redis_stdout_file, redis_stderr_file = new_redis_log_file(redirect_output) if redis_shard_ports is None: redis_shard_ports = num_redis_shards * [None] @@ -467,6 +465,13 @@ def start_redis(node_ip_address, if use_credis is None: use_credis = ("RAY_USE_NEW_GCS" in os.environ) + if use_credis and password is not None: + # TODO(pschafhalter) remove this once credis supports + # authenticating Redis ports + raise Exception("Setting the `redis_password` argument is not " + "supported in credis. To run Ray with " + "password-protected Redis ports, ensure that " + "the environment variable `RAY_USE_NEW_GCS=off`.") if not use_credis: assigned_port, _ = _start_redis_instance( node_ip_address=node_ip_address, @@ -475,7 +480,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode) + password=password, + redis_max_memory=None) else: assigned_port, _ = _start_redis_instance( node_ip_address=node_ip_address, @@ -484,25 +490,22 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode, executable=CREDIS_EXECUTABLE, # It is important to load the credis module BEFORE the ray module, # as the latter contains an extern declaration that the former # supplies. - modules=[CREDIS_MASTER_MODULE, REDIS_MODULE]) + modules=[CREDIS_MASTER_MODULE, REDIS_MODULE], + password=password, + redis_max_memory=None) if port is not None: assert assigned_port == port port = assigned_port redis_address = address(node_ip_address, port) - redis_client = redis.StrictRedis(host=node_ip_address, port=port) - - # Store whether we're using the raylet code path or not. - redis_client.set("UseRaylet", 1 if use_raylet else 0) - # Register the number of Redis shards in the primary shard, so that clients # know how many redis shards to expect under RedisShards. - primary_redis_client = redis.StrictRedis(host=node_ip_address, port=port) + primary_redis_client = redis.StrictRedis( + host=node_ip_address, port=port, password=password) primary_redis_client.set("NumRedisShards", str(num_redis_shards)) # Put the redirect_worker_output bool in the Redis shard so that workers @@ -517,8 +520,8 @@ def start_redis(node_ip_address, # prefixed by "redis-". 
redis_shards = [] for i in range(num_redis_shards): - redis_stdout_file, redis_stderr_file = new_log_files( - "redis-{}".format(i), redirect_output) + redis_stdout_file, redis_stderr_file = new_redis_log_file( + redirect_output, shard_number=i) if not use_credis: redis_shard_port, _ = _start_redis_instance( node_ip_address=node_ip_address, @@ -527,7 +530,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode) + password=password, + redis_max_memory=redis_max_memory) else: assert num_redis_shards == 1, \ "For now, RAY_USE_NEW_GCS supports 1 shard, and credis "\ @@ -539,12 +543,13 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode, + password=password, executable=CREDIS_EXECUTABLE, # It is important to load the credis module BEFORE the ray # module, as the latter contains an extern declaration that the # former supplies. - modules=[CREDIS_MEMBER_MODULE, REDIS_MODULE]) + modules=[CREDIS_MEMBER_MODULE, REDIS_MODULE], + redis_max_memory=redis_max_memory) if redis_shard_ports[i] is not None: assert redis_shard_port == redis_shard_ports[i] @@ -555,7 +560,7 @@ def start_redis(node_ip_address, if use_credis: shard_client = redis.StrictRedis( - host=node_ip_address, port=redis_shard_port) + host=node_ip_address, port=redis_shard_port, password=password) # Configure the chain state. primary_redis_client.execute_command("MASTER.ADD", node_ip_address, redis_shard_port) @@ -565,22 +570,6 @@ def start_redis(node_ip_address, return redis_address, redis_shards -def _make_temp_redis_config(node_ip_address): - """Create a configuration file for Redis. - - Args: - node_ip_address: The IP address of this node. This should not be - 127.0.0.1. - """ - redis_config_name = "/tmp/redis_conf{}".format(random_name()) - with open(redis_config_name, 'w') as f: - # This allows redis clients on the same machine to connect using the - # node's IP address as opposed to just 127.0.0.1. This is only relevant - # when the server is in protected mode. - f.write("bind 127.0.0.1 {}".format(node_ip_address)) - return redis_config_name - - def _start_redis_instance(node_ip_address="127.0.0.1", port=None, redis_max_clients=None, @@ -588,9 +577,10 @@ def _start_redis_instance(node_ip_address="127.0.0.1", stdout_file=None, stderr_file=None, cleanup=True, - protected_mode=False, + password=None, executable=REDIS_EXECUTABLE, - modules=None): + modules=None, + redis_max_memory=None): """Start a single Redis server. Args: @@ -608,14 +598,15 @@ def _start_redis_instance(node_ip_address="127.0.0.1", cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by serices.cleanup() when the Python process that imported services exits. - protected_mode: True if we should start the Redis server in protected - mode. This will prevent clients on other machines from connecting - and is only used when the Redis servers are started via ray.init() - as opposed to ray start. + password (str): Prevents external clients without the password + from connecting to Redis if provided. executable (str): Full path tho the redis-server executable. modules (list of str): A list of pathnames, pointing to the redis module(s) that will be loaded in this redis server. If None, load the default Ray redis module. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. 
Once the limit is exceeded, redis + will start LRU eviction of entries. Returns: A tuple of the port used by Redis and a handle to the process that was @@ -637,9 +628,6 @@ def _start_redis_instance(node_ip_address="127.0.0.1", else: port = new_port() - if protected_mode: - redis_config_filename = _make_temp_redis_config(node_ip_address) - load_module_args = [] for module in modules: load_module_args += ["--loadmodule", module] @@ -650,8 +638,8 @@ def _start_redis_instance(node_ip_address="127.0.0.1", # Construct the command to start the Redis server. command = [executable] - if protected_mode: - command += [redis_config_filename] + if password: + command += ["--requirepass", password] command += ( ["--port", str(port), "--loglevel", "warning"] + load_module_args) @@ -670,17 +658,25 @@ def _start_redis_instance(node_ip_address="127.0.0.1", stdout_file.name, stderr_file.name)) # Create a Redis client just for configuring Redis. - redis_client = redis.StrictRedis(host="127.0.0.1", port=port) + redis_client = redis.StrictRedis( + host="127.0.0.1", port=port, password=password) # Wait for the Redis server to start. - wait_for_redis_to_start("127.0.0.1", port) + wait_for_redis_to_start("127.0.0.1", port, password=password) # Configure Redis to generate keyspace notifications. TODO(rkn): Change # this to only generate notifications for the export keys. redis_client.config_set("notify-keyspace-events", "Kl") # Configure Redis to not run in protected mode so that processes on other # hosts can connect to it. TODO(rkn): Do this in a more secure way. - if not protected_mode: - redis_client.config_set("protected-mode", "no") + redis_client.config_set("protected-mode", "no") + + # Discard old task and object metadata. + if redis_max_memory is not None: + redis_client.config_set("maxmemory", str(redis_max_memory)) + redis_client.config_set("maxmemory-policy", "allkeys-lru") + redis_client.config_set("maxmemory-samples", "10") + logger.info("Starting Redis shard with {} GB max memory.".format( + round(redis_max_memory / 1e9, 2))) # If redis_max_clients is provided, attempt to raise the number of maximum # number of Redis clients. @@ -717,8 +713,9 @@ def _start_redis_instance(node_ip_address="127.0.0.1", redis_client.set("redis_start_time", time.time()) # Record the log files in Redis. record_log_files_in_redis( - address(node_ip_address, port), node_ip_address, - [stdout_file, stderr_file]) + address(node_ip_address, port), + node_ip_address, [stdout_file, stderr_file], + password=password) return port, p @@ -726,7 +723,8 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None, stderr_file=None, - cleanup=cleanup): + cleanup=cleanup, + redis_password=None): """Start a log monitor process. Args: @@ -740,50 +738,23 @@ def start_log_monitor(redis_address, cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by services.cleanup() when the Python process that imported services exits. + redis_password (str): The password of the redis server. 
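The `redis_max_memory` handling above boils down to three CONFIG SET calls on each sharded Redis instance. For reference, the equivalent applied by hand to a running shard (host, port, and the memory figure are placeholders):

    import redis

    redis_max_memory = 10 * 1000 * 1000 * 1000  # 10 GB, for illustration

    # Placeholder connection details for one of the sharded Redis instances.
    client = redis.StrictRedis(host="127.0.0.1", port=6380)
    client.config_set("maxmemory", str(redis_max_memory))
    client.config_set("maxmemory-policy", "allkeys-lru")
    client.config_set("maxmemory-samples", "10")
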
""" log_monitor_filepath = os.path.join( os.path.dirname(os.path.abspath(__file__)), "log_monitor.py") - p = subprocess.Popen( - [ - sys.executable, "-u", log_monitor_filepath, "--redis-address", - redis_address, "--node-ip-address", node_ip_address - ], - stdout=stdout_file, - stderr=stderr_file) + command = [ + sys.executable, "-u", log_monitor_filepath, "--redis-address", + redis_address, "--node-ip-address", node_ip_address + ] + if redis_password: + command += ["--redis-password", redis_password] + p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: all_processes[PROCESS_TYPE_LOG_MONITOR].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) - - -def start_global_scheduler(redis_address, - node_ip_address, - stdout_file=None, - stderr_file=None, - cleanup=True): - """Start a global scheduler process. - - Args: - redis_address (str): The address of the Redis instance. - node_ip_address: The IP address of the node that this scheduler will - run on. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by services.cleanup() when the - Python process that imported services exits. - """ - p = global_scheduler.start_global_scheduler( + record_log_files_in_redis( redis_address, - node_ip_address, - stdout_file=stdout_file, - stderr_file=stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + node_ip_address, [stdout_file, stderr_file], + password=redis_password) def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): @@ -799,15 +770,7 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): then this process will be killed by services.cleanup() when the Python process that imported services exits. """ - new_env = os.environ.copy() - notebook_filepath = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "WebUI.ipynb") - # We copy the notebook file so that the original doesn't get modified by - # the user. - random_ui_id = random.randint(0, 100000) - new_notebook_filepath = "/tmp/raylogs/ray_ui{}.ipynb".format(random_ui_id) - new_notebook_directory = os.path.dirname(new_notebook_filepath) - shutil.copy(notebook_filepath, new_notebook_filepath) + port = 8888 while True: try: @@ -821,7 +784,8 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): new_env["REDIS_ADDRESS"] = redis_address # We generate the token used for authentication ourselves to avoid # querying the jupyter server. - token = ray.utils.decode(binascii.hexlify(os.urandom(24))) + new_notebook_directory, webui_url, token = ( + get_ipython_notebook_path(port)) # The --ip=0.0.0.0 flag is intended to enable connecting to a notebook # running within a docker container (from the outside). 
command = [ @@ -847,21 +811,17 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): else: if cleanup: all_processes[PROCESS_TYPE_WEB_UI].append(ui_process) - webui_url = ("http://localhost:{}/notebooks/ray_ui{}.ipynb?token={}" - .format(port, random_ui_id, token)) logger.info("\n" + "=" * 70) logger.info("View the web UI at {}".format(webui_url)) logger.info("=" * 70 + "\n") return webui_url -def check_and_update_resources(resources, use_raylet): +def check_and_update_resources(resources): """Sanity check a resource dictionary and add sensible defaults. Args: resources: A dictionary mapping resource names to resource quantities. - use_raylet: True if we are using the raylet code path and false - otherwise. Returns: A new resource dictionary. @@ -900,86 +860,30 @@ def check_and_update_resources(resources, use_raylet): and not resource_quantity.is_integer()): raise ValueError("Resource quantities must all be whole numbers.") - if (use_raylet and - resource_quantity > ray.ray_constants.MAX_RESOURCE_QUANTITY): + if resource_quantity > ray.ray_constants.MAX_RESOURCE_QUANTITY: raise ValueError("Resource quantities must be at most {}.".format( ray.ray_constants.MAX_RESOURCE_QUANTITY)) return resources -def start_local_scheduler(redis_address, - node_ip_address, - plasma_store_name, - plasma_manager_name, - worker_path, - plasma_address=None, - stdout_file=None, - stderr_file=None, - cleanup=True, - resources=None, - num_workers=0): - """Start a local scheduler process. - - Args: - redis_address (str): The address of the Redis instance. - node_ip_address (str): The IP address of the node that this local - scheduler is running on. - plasma_store_name (str): The name of the plasma store socket to connect - to. - plasma_manager_name (str): The name of the plasma manager socket to - connect to. - worker_path (str): The path of the script to use when the local - scheduler starts up new workers. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by serices.cleanup() when the - Python process that imported services exits. - resources: A dictionary mapping the name of a resource to the available - quantity of that resource. - num_workers (int): The number of workers that the local scheduler - should start. - - Return: - The name of the local scheduler socket. - """ - resources = check_and_update_resources(resources, False) - - logger.info("Starting local scheduler with the following resources: {}." 
- .format(resources)) - local_scheduler_name, p = ray.local_scheduler.start_local_scheduler( - plasma_store_name, - plasma_manager_name, - worker_path=worker_path, - node_ip_address=node_ip_address, - redis_address=redis_address, - plasma_address=plasma_address, - use_profiler=RUN_LOCAL_SCHEDULER_PROFILER, - stdout_file=stdout_file, - stderr_file=stderr_file, - static_resources=resources, - num_workers=num_workers) - if cleanup: - all_processes[PROCESS_TYPE_LOCAL_SCHEDULER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) - return local_scheduler_name - - def start_raylet(redis_address, node_ip_address, + raylet_name, plasma_store_name, worker_path, resources=None, + object_manager_port=None, + node_manager_port=None, num_workers=0, use_valgrind=False, use_profiler=False, stdout_file=None, stderr_file=None, - cleanup=True): + cleanup=True, + config=None, + redis_password=None, + collect_profiling_data=True): """Start a raylet, which is a combined local scheduler and object manager. Args: @@ -988,8 +892,16 @@ def start_raylet(redis_address, scheduler is running on. plasma_store_name (str): The name of the plasma store socket to connect to. + raylet_name (str): The name of the raylet socket to create. worker_path (str): The path of the script to use when the local scheduler starts up new workers. + resources: The resources that this raylet has. + object_manager_port (int): The port to use for the object manager. If + this is not provided, we will use 0 and the object manager will + choose its own port. + node_manager_port (int): The port to use for the node manager. If + this is not provided, we will use 0 and the node manager will + choose its own port. use_valgrind (bool): True if the raylet should be started inside of valgrind. If this is True, use_profiler must be False. use_profiler (bool): True if the raylet should be started inside @@ -1001,14 +913,21 @@ def start_raylet(redis_address, cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by serices.cleanup() when the Python process that imported services exits. + config (dict|None): Optional Raylet configuration that will + override defaults in RayConfig. + redis_password (str): The password of the redis server. + collect_profiling_data: Whether to collect profiling data from workers. Returns: The raylet socket name. """ + config = config or {} + config_str = ",".join(["{},{}".format(*kv) for kv in config.items()]) + if use_valgrind and use_profiler: raise Exception("Cannot use valgrind and profiler at the same time.") - static_resources = check_and_update_resources(resources, True) + static_resources = check_and_update_resources(resources) # Limit the number of workers that can be started in parallel by the # raylet. However, make sure it is at least 1. @@ -1016,36 +935,52 @@ def start_raylet(redis_address, 1, min(multiprocessing.cpu_count(), static_resources["CPU"])) # Format the resource argument in a form like 'CPU,1.0,GPU,0,Custom,3'. - resource_argument = ",".join([ - "{},{}".format(resource_name, resource_value) - for resource_name, resource_value in zip(static_resources.keys(), - static_resources.values()) - ]) + resource_argument = ",".join( + ["{},{}".format(*kv) for kv in static_resources.items()]) gcs_ip_address, gcs_port = redis_address.split(":") - raylet_name = "/tmp/raylet{}".format(random_name()) # Create the command that the Raylet will use to start workers. 
start_worker_command = ("{} {} " "--node-ip-address={} " "--object-store-name={} " "--raylet-name={} " - "--redis-address={}".format( + "--redis-address={} " + "--collect-profiling-data={} " + "--temp-dir={}".format( sys.executable, worker_path, node_ip_address, - plasma_store_name, raylet_name, redis_address)) + plasma_store_name, raylet_name, redis_address, + "1" if collect_profiling_data else "0", + get_temp_root())) + if redis_password: + start_worker_command += " --redis-password {}".format(redis_password) + + # If the object manager port is None, then use 0 to cause the object + # manager to choose its own port. + if object_manager_port is None: + object_manager_port = 0 + # If the node manager port is None, then use 0 to cause the node manager + # to choose its own port. + if node_manager_port is None: + node_manager_port = 0 command = [ RAYLET_EXECUTABLE, raylet_name, plasma_store_name, + str(object_manager_port), + str(node_manager_port), node_ip_address, gcs_ip_address, gcs_port, str(num_workers), str(maximum_startup_concurrency), resource_argument, + config_str, start_worker_command, "", # Worker command for Java, not needed for Python. + redis_password or "", + get_temp_root(), ] if use_valgrind: @@ -1062,29 +997,114 @@ def start_raylet(redis_address, ["valgrind", "--tool=callgrind"] + command, stdout=stdout_file, stderr=stderr_file) + elif "RAYLET_PERFTOOLS_PATH" in os.environ: + modified_env = os.environ.copy() + modified_env["LD_PRELOAD"] = os.environ["RAYLET_PERFTOOLS_PATH"] + modified_env["CPUPROFILE"] = os.environ["RAYLET_PERFTOOLS_LOGFILE"] + pid = subprocess.Popen( + command, stdout=stdout_file, stderr=stderr_file, env=modified_env) else: pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: all_processes[PROCESS_TYPE_RAYLET].append(pid) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) return raylet_name +def determine_plasma_store_config(object_store_memory=None, + plasma_directory=None, + huge_pages=False): + """Figure out how to configure the plasma object store. + + This will determine which directory to use for the plasma store (e.g., + /tmp or /dev/shm) and how much memory to start the store with. On Linux, + we will try to use /dev/shm unless the shared memory file system is too + small, in which case we will fall back to /tmp. If any of the object store + memory or plasma directory parameters are specified by the user, then those + values will be preserved. + + Args: + object_store_memory (int): The user-specified object store memory + parameter. + plasma_directory (str): The user-specified plasma directory parameter. + huge_pages (bool): The user-specified huge pages parameter. + + Returns: + A tuple of the object store memory to use and the plasma directory to + use. If either of these values is specified by the user, then that + value will be preserved. + """ + system_memory = ray.utils.get_system_memory() + + # Choose a default object store size. + if object_store_memory is None: + object_store_memory = int(system_memory * 0.4) + # Cap memory to avoid memory waste and perf issues on large nodes + if object_store_memory > MAX_DEFAULT_MEM: + logger.warning( + "Warning: Capping object memory store to {}GB. 
".format( + MAX_DEFAULT_MEM // 1e9) + + "To increase this further, specify `object_store_memory` " + "when calling ray.init() or ray start.") + object_store_memory = MAX_DEFAULT_MEM + + # Determine which directory to use. By default, use /tmp on MacOS and + # /dev/shm on Linux, unless the shared-memory file system is too small, + # in which case we default to /tmp on Linux. + if plasma_directory is None: + if sys.platform == "linux" or sys.platform == "linux2": + shm_avail = ray.utils.get_shared_memory_bytes() + # Compare the requested memory size to the memory available in + # /dev/shm. + if shm_avail > object_store_memory: + plasma_directory = "/dev/shm" + else: + plasma_directory = "/tmp" + logger.warning( + "WARNING: The object store is using /tmp instead of " + "/dev/shm because /dev/shm has only {} bytes available. " + "This may slow down performance! You may be able to free " + "up space by deleting files in /dev/shm or terminating " + "any running plasma_store_server processes. If you are " + "inside a Docker container, you may need to pass an " + "argument with the flag '--shm-size' to 'docker run'." + .format(shm_avail)) + else: + plasma_directory = "/tmp" + + # Do some sanity checks. + if object_store_memory > system_memory: + raise Exception( + "The requested object store memory size is greater " + "than the total available memory.") + else: + plasma_directory = os.path.abspath(plasma_directory) + logger.warning("WARNING: object_store_memory is not verified when " + "plasma_directory is set.") + + if not os.path.isdir(plasma_directory): + raise Exception("The file {} does not exist or is not a directory." + .format(plasma_directory)) + + return object_store_memory, plasma_directory + + def start_plasma_store(node_ip_address, redis_address, object_manager_port=None, store_stdout_file=None, store_stderr_file=None, - manager_stdout_file=None, - manager_stderr_file=None, - objstore_memory=None, + object_store_memory=None, cleanup=True, plasma_directory=None, huge_pages=False, - use_raylet=False): + plasma_store_socket_name=None, + redis_password=None): """This method starts an object store process. Args: @@ -1097,14 +1117,8 @@ def start_plasma_store(node_ip_address, to. If no redirection should happen, then this should be None. store_stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. - manager_stdout_file: A file handle opened for writing to redirect - stdout to. If no redirection should happen, then this should be - None. - manager_stderr_file: A file handle opened for writing to redirect - stderr to. If no redirection should happen, then this should be - None. - objstore_memory: The amount of memory (in bytes) to start the object - store with. + object_store_memory: The amount of memory (in bytes) to start the + object store with. cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by serices.cleanup() when the Python process that imported services exits. @@ -1112,97 +1126,40 @@ def start_plasma_store(node_ip_address, be created. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. - use_raylet: True if the new raylet code path should be used. This is - not supported yet. + redis_password (str): The password of the redis server. Return: - A tuple of the Plasma store socket name, the Plasma manager socket - name, and the plasma manager port. + The Plasma store socket name. 
""" - if objstore_memory is None: - # Compute a fraction of the system memory for the Plasma store to use. - system_memory = ray.utils.get_system_memory() - if sys.platform == "linux" or sys.platform == "linux2": - # On linux we use /dev/shm, its size is half the size of the - # physical memory. To not overflow it, we set the plasma memory - # limit to 0.4 times the size of the physical memory. - objstore_memory = int(system_memory * 0.4) - # Compare the requested memory size to the memory available in - # /dev/shm. - shm_fd = os.open("/dev/shm", os.O_RDONLY) - try: - shm_fs_stats = os.fstatvfs(shm_fd) - # The value shm_fs_stats.f_bsize is the block size and the - # value shm_fs_stats.f_bavail is the number of available - # blocks. - shm_avail = shm_fs_stats.f_bsize * shm_fs_stats.f_bavail - if objstore_memory > shm_avail: - logger.warning( - "Warning: Reducing object store memory because " - "/dev/shm has only {} bytes available. You may be " - "able to free up space by deleting files in " - "/dev/shm. If you are inside a Docker container, " - "you may need to pass an argument with the flag " - "'--shm-size' to 'docker run'.".format(shm_avail)) - objstore_memory = int(shm_avail * 0.8) - finally: - os.close(shm_fd) - else: - objstore_memory = int(system_memory * 0.8) + object_store_memory, plasma_directory = determine_plasma_store_config( + object_store_memory, plasma_directory, huge_pages) + + # Print the object store memory using two decimal places. + object_store_memory_str = (object_store_memory / 10**7) / 10**2 + logger.info("Starting the Plasma object store with {} GB memory " + "using {}.".format(object_store_memory_str, plasma_directory)) # Start the Plasma store. - logger.info("Starting the Plasma object store with {0:.2f} GB memory." - .format(objstore_memory // 10**9)) plasma_store_name, p1 = ray.plasma.start_plasma_store( - plasma_store_memory=objstore_memory, + plasma_store_memory=object_store_memory, use_profiler=RUN_PLASMA_STORE_PROFILER, stdout_file=store_stdout_file, stderr_file=store_stderr_file, plasma_directory=plasma_directory, - huge_pages=huge_pages) - # Start the plasma manager. 
- if not use_raylet: - if object_manager_port is not None: - (plasma_manager_name, p2, - plasma_manager_port) = ray.plasma.start_plasma_manager( - plasma_store_name, - redis_address, - plasma_manager_port=object_manager_port, - node_ip_address=node_ip_address, - num_retries=1, - run_profiler=RUN_PLASMA_MANAGER_PROFILER, - stdout_file=manager_stdout_file, - stderr_file=manager_stderr_file) - assert plasma_manager_port == object_manager_port - else: - (plasma_manager_name, p2, - plasma_manager_port) = ray.plasma.start_plasma_manager( - plasma_store_name, - redis_address, - node_ip_address=node_ip_address, - run_profiler=RUN_PLASMA_MANAGER_PROFILER, - stdout_file=manager_stdout_file, - stderr_file=manager_stderr_file) - else: - plasma_manager_port = None - plasma_manager_name = None + huge_pages=huge_pages, + socket_name=plasma_store_socket_name) if cleanup: all_processes[PROCESS_TYPE_PLASMA_STORE].append(p1) - record_log_files_in_redis(redis_address, node_ip_address, - [store_stdout_file, store_stderr_file]) - if not use_raylet: - if cleanup: - all_processes[PROCESS_TYPE_PLASMA_MANAGER].append(p2) - record_log_files_in_redis(redis_address, node_ip_address, - [manager_stdout_file, manager_stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [store_stdout_file, store_stderr_file], + password=redis_password) - return ObjectStoreAddress(plasma_store_name, plasma_manager_name, - plasma_manager_port) + return plasma_store_name def start_worker(node_ip_address, object_store_name, - object_store_manager_name, local_scheduler_name, redis_address, worker_path, @@ -1215,7 +1172,6 @@ def start_worker(node_ip_address, node_ip_address (str): The IP address of the node that this worker is running on. object_store_name (str): The name of the object store. - object_store_manager_name (str): The name of the object store manager. local_scheduler_name (str): The name of the local scheduler. redis_address (str): The address that the Redis server is listening on. worker_path (str): The path of the source code which the worker process @@ -1233,9 +1189,8 @@ def start_worker(node_ip_address, sys.executable, "-u", worker_path, "--node-ip-address=" + node_ip_address, "--object-store-name=" + object_store_name, - "--object-store-manager-name=" + object_store_manager_name, - "--local-scheduler-name=" + local_scheduler_name, - "--redis-address=" + str(redis_address) + "--redis-address=" + str(redis_address), + "--temp-dir=" + get_temp_root() ] p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: @@ -1249,7 +1204,8 @@ def start_monitor(redis_address, stdout_file=None, stderr_file=None, cleanup=True, - autoscaling_config=None): + autoscaling_config=None, + redis_password=None): """Run a process to monitor the other processes. Args: @@ -1265,6 +1221,7 @@ def start_monitor(redis_address, Python process that imported services exits. This is True by default. autoscaling_config: path to autoscaling config file. + redis_password (str): The password of the redis server. 
""" monitor_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "monitor.py") @@ -1274,17 +1231,23 @@ def start_monitor(redis_address, ] if autoscaling_config: command.append("--autoscaling-config=" + str(autoscaling_config)) + if redis_password: + command.append("--redis-password=" + redis_password) p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: all_processes[PROCESS_TYPE_MONITOR].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) def start_raylet_monitor(redis_address, stdout_file=None, stderr_file=None, - cleanup=True): + cleanup=True, + redis_password=None, + config=None): """Run a process to monitor the other processes. Args: @@ -1297,29 +1260,40 @@ def start_raylet_monitor(redis_address, then this process will be killed by services.cleanup() when the Python process that imported services exits. This is True by default. + redis_password (str): The password of the redis server. + config (dict|None): Optional configuration that will + override defaults in RayConfig. """ gcs_ip_address, gcs_port = redis_address.split(":") - command = [RAYLET_MONITOR_EXECUTABLE, gcs_ip_address, gcs_port] + redis_password = redis_password or "" + config = config or {} + config_str = ",".join(["{},{}".format(*kv) for kv in config.items()]) + command = [RAYLET_MONITOR_EXECUTABLE, gcs_ip_address, gcs_port, config_str] + if redis_password: + command += [redis_password] p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: all_processes[PROCESS_TYPE_MONITOR].append(p) def start_ray_processes(address_info=None, + object_manager_ports=None, + node_manager_ports=None, node_ip_address="127.0.0.1", redis_port=None, redis_shard_ports=None, num_workers=None, num_local_schedulers=1, object_store_memory=None, + redis_max_memory=None, + collect_profiling_data=True, num_redis_shards=1, redis_max_clients=None, - redis_protected_mode=False, + redis_password=None, worker_path=None, cleanup=True, redirect_worker_output=False, redirect_output=False, - include_global_scheduler=False, include_log_monitor=False, include_webui=False, start_workers_from_local_scheduler=True, @@ -1327,13 +1301,22 @@ def start_ray_processes(address_info=None, plasma_directory=None, huge_pages=False, autoscaling_config=None, - use_raylet=False): + plasma_store_socket_name=None, + raylet_socket_name=None, + temp_dir=None, + _internal_config=None): """Helper method to start Ray processes. Args: address_info (dict): A dictionary with address information for processes that have already been started. If provided, address_info will be modified to include processes that are newly started. + object_manager_ports (list): A list of the ports to use for the object + managers. There should be one per object manager being started on + this node (typically just one). + node_manager_ports (list): A list of the ports to use for the node + managers. There should be one per node manager being started on + this node (typically just one). node_ip_address (str): The IP address of this node. redis_port (int): The port that the primary Redis shard should listen to. If None, then a random port will be chosen. If the key @@ -1350,13 +1333,20 @@ def start_ray_processes(address_info=None, address_info. object_store_memory: The amount of memory (in bytes) to start the object store with. 
+ redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. This only applies to the + sharded redis tables (task and object tables). + collect_profiling_data: Whether to collect profiling data. Note that + profiling data cannot be LRU evicted, so if you set + redis_max_memory then profiling will also be disabled to prevent + it from consuming all available redis memory. num_redis_shards: The number of Redis shards to start in addition to the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this maxclients number. - redis_protected_mode: True if we should start Redis in protected mode. - This will prevent clients from other machines from connecting and - is only done when Redis is started via ray.init(). + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. worker_path (str): The path of the source code that will be run by the worker. cleanup (bool): If cleanup is true, then the processes started here @@ -1366,8 +1356,6 @@ def start_ray_processes(address_info=None, processes should be redirected to files. redirect_output (bool): True if stdout and stderr for non-worker processes should be redirected to files and false otherwise. - include_global_scheduler (bool): If include_global_scheduler is True, - then start a global scheduler process. include_log_monitor (bool): If True, then start a log monitor to monitor the log files for all processes on this node and push their contents to Redis. @@ -1383,15 +1371,26 @@ def start_ray_processes(address_info=None, huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. autoscaling_config: path to autoscaling config file. - use_raylet: True if the new raylet code path should be used. This is - not supported yet. + plasma_store_socket_name (str): If provided, it will specify the socket + name used by the plasma store. + raylet_socket_name (str): If provided, it will specify the socket path + used by the raylet process. + temp_dir (str): If provided, it will specify the root temporary + directory for the Ray process. + _internal_config (str): JSON configuration for overriding + RayConfig defaults. For testing purposes ONLY. Returns: A dictionary of the address information for the processes that were started. """ - logger.info( - "Process STDOUT and STDERR is being redirected to /tmp/raylogs/.") + + set_temp_root(temp_dir) + + logger.info("Process STDOUT and STDERR is being redirected to {}.".format( + get_logs_dir_path())) + + config = json.loads(_internal_config) if _internal_config else None if resources is None: resources = {} @@ -1399,7 +1398,8 @@ def start_ray_processes(address_info=None, resources = num_local_schedulers * [resources] if num_workers is not None: - workers_per_local_scheduler = num_local_schedulers * [num_workers] + raise Exception("The 'num_workers' argument is deprecated. 
Please use " + "'num_cpus' instead.") else: workers_per_local_scheduler = [] for resource_dict in resources: @@ -1429,190 +1429,116 @@ def start_ray_processes(address_info=None, redis_shard_ports=redis_shard_ports, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, - use_raylet=use_raylet, redirect_output=True, redirect_worker_output=redirect_worker_output, cleanup=cleanup, - protected_mode=redis_protected_mode) + password=redis_password, + redis_max_memory=redis_max_memory) address_info["redis_address"] = redis_address time.sleep(0.1) # Start monitoring the processes. - monitor_stdout_file, monitor_stderr_file = new_log_files( - "monitor", redirect_output) + monitor_stdout_file, monitor_stderr_file = new_monitor_log_file( + redirect_output) start_monitor( redis_address, node_ip_address, stdout_file=monitor_stdout_file, stderr_file=monitor_stderr_file, cleanup=cleanup, - autoscaling_config=autoscaling_config) - if use_raylet: - start_raylet_monitor( - redis_address, - stdout_file=monitor_stdout_file, - stderr_file=monitor_stderr_file, - cleanup=cleanup) + autoscaling_config=autoscaling_config, + redis_password=redis_password) + start_raylet_monitor( + redis_address, + stdout_file=monitor_stdout_file, + stderr_file=monitor_stderr_file, + cleanup=cleanup, + redis_password=redis_password, + config=config) if redis_shards == []: # Get redis shards from primary redis instance. redis_ip_address, redis_port = redis_address.split(":") redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=redis_password) redis_shards = redis_client.lrange("RedisShards", start=0, end=-1) redis_shards = [ray.utils.decode(shard) for shard in redis_shards] address_info["redis_shards"] = redis_shards # Start the log monitor, if necessary. if include_log_monitor: - log_monitor_stdout_file, log_monitor_stderr_file = new_log_files( - "log_monitor", redirect_output=True) + log_monitor_stdout_file, log_monitor_stderr_file = ( + new_log_monitor_log_file()) start_log_monitor( redis_address, node_ip_address, stdout_file=log_monitor_stdout_file, stderr_file=log_monitor_stderr_file, - cleanup=cleanup) - - # Start the global scheduler, if necessary. - if include_global_scheduler and not use_raylet: - global_scheduler_stdout_file, global_scheduler_stderr_file = ( - new_log_files("global_scheduler", redirect_output)) - start_global_scheduler( - redis_address, - node_ip_address, - stdout_file=global_scheduler_stdout_file, - stderr_file=global_scheduler_stderr_file, - cleanup=cleanup) + cleanup=cleanup, + redis_password=redis_password) # Initialize with existing services. if "object_store_addresses" not in address_info: address_info["object_store_addresses"] = [] object_store_addresses = address_info["object_store_addresses"] - if "local_scheduler_socket_names" not in address_info: - address_info["local_scheduler_socket_names"] = [] - local_scheduler_socket_names = address_info["local_scheduler_socket_names"] if "raylet_socket_names" not in address_info: address_info["raylet_socket_names"] = [] raylet_socket_names = address_info["raylet_socket_names"] # Get the ports to use for the object managers if any are provided. 
- object_manager_ports = (address_info["object_manager_ports"] if - "object_manager_ports" in address_info else None) if not isinstance(object_manager_ports, list): + assert object_manager_ports is None or num_local_schedulers == 1 object_manager_ports = num_local_schedulers * [object_manager_ports] assert len(object_manager_ports) == num_local_schedulers + if not isinstance(node_manager_ports, list): + assert node_manager_ports is None or num_local_schedulers == 1 + node_manager_ports = num_local_schedulers * [node_manager_ports] + assert len(node_manager_ports) == num_local_schedulers # Start any object stores that do not yet exist. for i in range(num_local_schedulers - len(object_store_addresses)): # Start Plasma. - plasma_store_stdout_file, plasma_store_stderr_file = new_log_files( - "plasma_store_{}".format(i), redirect_output) - plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files( - "plasma_manager_{}".format(i), redirect_output) + plasma_store_stdout_file, plasma_store_stderr_file = ( + new_plasma_store_log_file(i, redirect_output)) + object_store_address = start_plasma_store( node_ip_address, redis_address, - object_manager_port=object_manager_ports[i], store_stdout_file=plasma_store_stdout_file, store_stderr_file=plasma_store_stderr_file, - manager_stdout_file=plasma_manager_stdout_file, - manager_stderr_file=plasma_manager_stderr_file, - objstore_memory=object_store_memory, + object_store_memory=object_store_memory, cleanup=cleanup, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet) + plasma_store_socket_name=plasma_store_socket_name, + redis_password=redis_password) object_store_addresses.append(object_store_address) time.sleep(0.1) - if not use_raylet: - # Start any local schedulers that do not yet exist. - for i in range( - len(local_scheduler_socket_names), num_local_schedulers): - # Connect the local scheduler to the object store at the same - # index. - object_store_address = object_store_addresses[i] - plasma_address = "{}:{}".format(node_ip_address, - object_store_address.manager_port) - # Determine how many workers this local scheduler should start. - if start_workers_from_local_scheduler: - num_local_scheduler_workers = workers_per_local_scheduler[i] - workers_per_local_scheduler[i] = 0 - else: - # If we're starting the workers from Python, the local - # scheduler should not start any workers. - num_local_scheduler_workers = 0 - # Start the local scheduler. Note that if we do not wish to - # redirect the worker output, then we cannot redirect the local - # scheduler output. - local_scheduler_stdout_file, local_scheduler_stderr_file = ( - new_log_files( - "local_scheduler_{}".format(i), - redirect_output=redirect_worker_output)) - local_scheduler_name = start_local_scheduler( + # Start any raylets that do not exist yet. 
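The port handling above accepts either a per-scheduler list or a single value (or None); a scalar is only legal when one local scheduler is started and is broadcast into a list so the raylet loop below can index it positionally. A minimal sketch of that normalization (the helper name is illustrative):

    def normalize_ports(ports, num_local_schedulers):
        # A scalar (or None) is only meaningful for a single local scheduler;
        # broadcast it so callers can always index by scheduler number.
        if not isinstance(ports, list):
            assert ports is None or num_local_schedulers == 1
            ports = num_local_schedulers * [ports]
        assert len(ports) == num_local_schedulers
        return ports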
+ for i in range(len(raylet_socket_names), num_local_schedulers): + raylet_stdout_file, raylet_stderr_file = new_raylet_log_file( + i, redirect_output=redirect_worker_output) + address_info["raylet_socket_names"].append( + start_raylet( redis_address, node_ip_address, - object_store_address.name, - object_store_address.manager_name, + raylet_socket_name or get_raylet_socket_name(), + object_store_addresses[i], worker_path, - plasma_address=plasma_address, - stdout_file=local_scheduler_stdout_file, - stderr_file=local_scheduler_stderr_file, - cleanup=cleanup, + object_manager_port=object_manager_ports[i], + node_manager_port=node_manager_ports[i], resources=resources[i], - num_workers=num_local_scheduler_workers) - local_scheduler_socket_names.append(local_scheduler_name) - - # Make sure that we have exactly num_local_schedulers instances of - # object stores and local schedulers. - assert len(object_store_addresses) == num_local_schedulers - assert len(local_scheduler_socket_names) == num_local_schedulers - - else: - # Start any raylets that do not exist yet. - for i in range(len(raylet_socket_names), num_local_schedulers): - raylet_stdout_file, raylet_stderr_file = new_log_files( - "raylet_{}".format(i), redirect_output=redirect_worker_output) - address_info["raylet_socket_names"].append( - start_raylet( - redis_address, - node_ip_address, - object_store_addresses[i].name, - worker_path, - resources=resources[i], - num_workers=workers_per_local_scheduler[i], - stdout_file=raylet_stdout_file, - stderr_file=raylet_stderr_file, - cleanup=cleanup)) - - if not use_raylet: - # Start any workers that the local scheduler has not already started. - for i, num_local_scheduler_workers in enumerate( - workers_per_local_scheduler): - object_store_address = object_store_addresses[i] - local_scheduler_name = local_scheduler_socket_names[i] - for j in range(num_local_scheduler_workers): - worker_stdout_file, worker_stderr_file = new_log_files( - "worker_{}_{}".format(i, j), redirect_output) - start_worker( - node_ip_address, - object_store_address.name, - object_store_address.manager_name, - local_scheduler_name, - redis_address, - worker_path, - stdout_file=worker_stdout_file, - stderr_file=worker_stderr_file, - cleanup=cleanup) - workers_per_local_scheduler[i] -= 1 - - # Make sure that we've started all the workers. - assert (sum(workers_per_local_scheduler) == 0) + num_workers=workers_per_local_scheduler[i], + stdout_file=raylet_stdout_file, + stderr_file=raylet_stderr_file, + cleanup=cleanup, + redis_password=redis_password, + collect_profiling_data=collect_profiling_data, + config=config)) # Try to start the web UI. if include_webui: - ui_stdout_file, ui_stderr_file = new_log_files( - "webui", redirect_output=True) + ui_stdout_file, ui_stderr_file = new_webui_log_file() address_info["webui_url"] = start_ui( redis_address, stdout_file=ui_stdout_file, @@ -1627,9 +1553,11 @@ def start_ray_processes(address_info=None, def start_ray_node(node_ip_address, redis_address, object_manager_ports=None, - num_workers=0, + node_manager_ports=None, + num_workers=None, num_local_schedulers=1, object_store_memory=None, + redis_password=None, worker_path=None, cleanup=True, redirect_worker_output=False, @@ -1637,7 +1565,10 @@ def start_ray_node(node_ip_address, resources=None, plasma_directory=None, huge_pages=False, - use_raylet=False): + plasma_store_socket_name=None, + raylet_socket_name=None, + temp_dir=None, + _internal_config=None): """Start the Ray processes for a single node. 
This assumes that the Ray processes on some master node have already been @@ -1649,12 +1580,17 @@ def start_ray_node(node_ip_address, object_manager_ports (list): A list of the ports to use for the object managers. There should be one per object manager being started on this node (typically just one). + node_manager_ports (list): A list of the ports to use for the node + managers. There should be one per node manager being started on + this node (typically just one). num_workers (int): The number of workers to start. num_local_schedulers (int): The number of local schedulers to start. This is also the number of plasma stores and plasma managers to start. object_store_memory (int): The maximum amount of memory (in bytes) to let the plasma store use. + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. worker_path (str): The path of the source code that will be run by the worker. cleanup (bool): If cleanup is true, then the processes started here @@ -1670,8 +1606,14 @@ def start_ray_node(node_ip_address, be created. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. - use_raylet: True if the new raylet code path should be used. This is - not supported yet. + plasma_store_socket_name (str): If provided, it will specify the socket + name used by the plasma store. + raylet_socket_name (str): If provided, it will specify the socket path + used by the raylet process. + temp_dir (str): If provided, it will specify the root temporary + directory for the Ray process. + _internal_config (str): JSON configuration for overriding + RayConfig defaults. For testing purposes ONLY. Returns: A dictionary of the address information for the processes that were @@ -1679,14 +1621,16 @@ def start_ray_node(node_ip_address, """ address_info = { "redis_address": redis_address, - "object_manager_ports": object_manager_ports } return start_ray_processes( address_info=address_info, + object_manager_ports=object_manager_ports, + node_manager_ports=node_manager_ports, node_ip_address=node_ip_address, num_workers=num_workers, num_local_schedulers=num_local_schedulers, object_store_memory=object_store_memory, + redis_password=redis_password, worker_path=worker_path, include_log_monitor=True, cleanup=cleanup, @@ -1695,16 +1639,23 @@ def start_ray_node(node_ip_address, resources=resources, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet) + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir, + _internal_config=_internal_config) def start_ray_head(address_info=None, + object_manager_ports=None, + node_manager_ports=None, node_ip_address="127.0.0.1", redis_port=None, redis_shard_ports=None, - num_workers=0, + num_workers=None, num_local_schedulers=1, object_store_memory=None, + redis_max_memory=None, + collect_profiling_data=True, worker_path=None, cleanup=True, redirect_worker_output=False, @@ -1713,18 +1664,27 @@ def start_ray_head(address_info=None, resources=None, num_redis_shards=None, redis_max_clients=None, - redis_protected_mode=False, + redis_password=None, include_webui=True, plasma_directory=None, huge_pages=False, autoscaling_config=None, - use_raylet=False): + plasma_store_socket_name=None, + raylet_socket_name=None, + temp_dir=None, + _internal_config=None): """Start Ray in local mode. Args: address_info (dict): A dictionary with address information for processes that have already been started. 
If provided, address_info will be modified to include processes that are newly started. + object_manager_ports (list): A list of the ports to use for the object + managers. There should be one per object manager being started on + this node (typically just one). + node_manager_ports (list): A list of the ports to use for the node + managers. There should be one per node manager being started on + this node (typically just one). node_ip_address (str): The IP address of this node. redis_port (int): The port that the primary Redis shard should listen to. If None, then a random port will be chosen. If the key @@ -1741,6 +1701,11 @@ def start_ray_head(address_info=None, address_info. object_store_memory: The amount of memory (in bytes) to start the object store with. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. This only applies to the + sharded redis tables (task and object tables). + collect_profiling_data: Whether to collect profiling data from workers. worker_path (str): The path of the source code that will be run by the worker. cleanup (bool): If cleanup is true, then the processes started here @@ -1759,17 +1724,22 @@ def start_ray_head(address_info=None, the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this maxclients number. - redis_protected_mode: True if we should start Redis in protected mode. - This will prevent clients from other machines from connecting and - is only done when Redis is started via ray.init(). + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. include_webui: True if the UI should be started and false otherwise. plasma_directory: A directory where the Plasma memory mapped files will be created. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. autoscaling_config: path to autoscaling config file. - use_raylet: True if the new raylet code path should be used. This is - not supported yet. + plasma_store_socket_name (str): If provided, it will specify the socket + name used by the plasma store. + raylet_socket_name (str): If provided, it will specify the socket path + used by the raylet process. + temp_dir (str): If provided, it will specify the root temporary + directory for the Ray process. + _internal_config (str): JSON configuration for overriding + RayConfig defaults. For testing purposes ONLY. 
Returns: A dictionary of the address information for the processes that were @@ -1778,79 +1748,31 @@ def start_ray_head(address_info=None, num_redis_shards = 1 if num_redis_shards is None else num_redis_shards return start_ray_processes( address_info=address_info, + object_manager_ports=object_manager_ports, + node_manager_ports=node_manager_ports, node_ip_address=node_ip_address, redis_port=redis_port, redis_shard_ports=redis_shard_ports, num_workers=num_workers, num_local_schedulers=num_local_schedulers, object_store_memory=object_store_memory, + redis_max_memory=redis_max_memory, + collect_profiling_data=collect_profiling_data, worker_path=worker_path, cleanup=cleanup, redirect_worker_output=redirect_worker_output, redirect_output=redirect_output, - include_global_scheduler=True, include_log_monitor=True, include_webui=include_webui, start_workers_from_local_scheduler=start_workers_from_local_scheduler, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, - redis_protected_mode=redis_protected_mode, + redis_password=redis_password, plasma_directory=plasma_directory, huge_pages=huge_pages, autoscaling_config=autoscaling_config, - use_raylet=use_raylet) - - -def try_to_create_directory(directory_path): - """Attempt to create a directory that is globally readable/writable. - - Args: - directory_path: The path of the directory to create. - """ - if not os.path.exists(directory_path): - try: - os.makedirs(directory_path) - except OSError as e: - if e.errno != os.errno.EEXIST: - raise e - logger.warning( - "Attempted to create '{}', but the directory already " - "exists.".format(directory_path)) - # Change the log directory permissions so others can use it. This is - # important when multiple people are using the same machine. - os.chmod(directory_path, 0o0777) - - -def new_log_files(name, redirect_output): - """Generate partially randomized filenames for log files. - - Args: - name (str): descriptive string for this log file. - redirect_output (bool): True if files should be generated for logging - stdout and stderr and false if stdout and stderr should not be - redirected. - - Returns: - If redirect_output is true, this will return a tuple of two - filehandles. The first is for redirecting stdout and the second is - for redirecting stderr. If redirect_output is false, this will - return a tuple of two None objects. - """ - if not redirect_output: - return None, None - - # Create a directory to be used for process log files. - logs_dir = "/tmp/raylogs" - try_to_create_directory(logs_dir) - # Create another directory that will be used by some of the RL algorithms. 
- try_to_create_directory("/tmp/ray") - - log_id = random.randint(0, 10000) - date_str = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") - log_stdout = "{}/{}-{}-{:05d}.out".format(logs_dir, name, date_str, log_id) - log_stderr = "{}/{}-{}-{:05d}.err".format(logs_dir, name, date_str, log_id) - # Line-buffer the output (mode 1) - log_stdout_file = open(log_stdout, "a", buffering=1) - log_stderr_file = open(log_stderr, "a", buffering=1) - return log_stdout_file, log_stderr_file + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir, + _internal_config=_internal_config) diff --git a/python/ray/tempfile_services.py b/python/ray/tempfile_services.py new file mode 100644 index 0000000000000..d4e94aec8a2ae --- /dev/null +++ b/python/ray/tempfile_services.py @@ -0,0 +1,226 @@ +import binascii +import collections +import datetime +import errno +import logging +import os +import shutil +import tempfile + +import ray.utils + +logger = logging.getLogger(__name__) +_incremental_dict = collections.defaultdict(lambda: 0) +_temp_root = None + + +def make_inc_temp(suffix="", prefix="", directory_name="/tmp/ray"): + """Return a incremental temporary file name. The file is not created. + + Args: + suffix (str): The suffix of the temp file. + prefix (str): The prefix of the temp file. + directory_name (str) : The base directory of the temp file. + + Returns: + A string of file name. If there existing a file having the same name, + the returned name will look like + "{directory_name}/{prefix}.{unique_index}{suffix}" + """ + directory_name = os.path.expanduser(directory_name) + index = _incremental_dict[suffix, prefix, directory_name] + # `tempfile.TMP_MAX` could be extremely large, + # so using `range` in Python2.x should be avoided. + while index < tempfile.TMP_MAX: + if index == 0: + filename = os.path.join(directory_name, prefix + suffix) + else: + filename = os.path.join(directory_name, + prefix + "." + str(index) + suffix) + index += 1 + if not os.path.exists(filename): + _incremental_dict[suffix, prefix, + directory_name] = index # Save the index. + return filename + + raise FileExistsError(errno.EEXIST, "No usable temporary filename found") + + +def try_to_create_directory(directory_path): + """Attempt to create a directory that is globally readable/writable. + + Args: + directory_path: The path of the directory to create. + """ + directory_path = os.path.expanduser(directory_path) + if not os.path.exists(directory_path): + try: + os.makedirs(directory_path) + except OSError as e: + if e.errno != os.errno.EEXIST: + raise e + logger.warning( + "Attempted to create '{}', but the directory already " + "exists.".format(directory_path)) + # Change the log directory permissions so others can use it. This is + # important when multiple people are using the same machine. + os.chmod(directory_path, 0o0777) + + +def get_temp_root(): + """Get the path of the temporary root. If not existing, it will be created. + """ + global _temp_root + + date_str = datetime.datetime.today().strftime("%Y-%m-%d_%H-%M-%S") + + # Lazy creation. Avoid creating directories never used. + if _temp_root is None: + _temp_root = make_inc_temp( + prefix="session_{date_str}_{pid}".format( + pid=os.getpid(), date_str=date_str), + directory_name="/tmp/ray") + try_to_create_directory(_temp_root) + return _temp_root + + +def set_temp_root(path): + """Set the path of the temporary root. 
It will be created lazily.""" + global _temp_root + _temp_root = path + + +def get_logs_dir_path(): + """Get a temp dir for logging.""" + logs_dir = os.path.join(get_temp_root(), "logs") + try_to_create_directory(logs_dir) + return logs_dir + + +def get_sockets_dir_path(): + """Get a temp dir for sockets.""" + sockets_dir = os.path.join(get_temp_root(), "sockets") + try_to_create_directory(sockets_dir) + return sockets_dir + + +def get_raylet_socket_name(suffix=""): + """Get a socket name for raylet.""" + sockets_dir = get_sockets_dir_path() + + raylet_socket_name = make_inc_temp( + prefix="raylet", directory_name=sockets_dir, suffix=suffix) + return raylet_socket_name + + +def get_object_store_socket_name(): + """Get a socket name for plasma object store.""" + sockets_dir = get_sockets_dir_path() + return make_inc_temp(prefix="plasma_store", directory_name=sockets_dir) + + +def get_ipython_notebook_path(port): + """Get a new ipython notebook path""" + + notebook_filepath = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "WebUI.ipynb") + # We copy the notebook file so that the original doesn't get modified by + # the user. + notebook_name = make_inc_temp( + suffix=".ipynb", prefix="ray_ui", directory_name=get_temp_root()) + shutil.copy(notebook_filepath, notebook_name) + new_notebook_directory = os.path.dirname(notebook_name) + token = ray.utils.decode(binascii.hexlify(os.urandom(24))) + webui_url = ("http://localhost:{}/notebooks/{}?token={}".format( + port, os.path.basename(notebook_name), token)) + return new_notebook_directory, webui_url, token + + +def new_log_files(name, redirect_output): + """Generate partially randomized filenames for log files. + + Args: + name (str): descriptive string for this log file. + redirect_output (bool): True if files should be generated for logging + stdout and stderr and false if stdout and stderr should not be + redirected. + + Returns: + If redirect_output is true, this will return a tuple of two + filehandles. The first is for redirecting stdout and the second is + for redirecting stderr. If redirect_output is false, this will + return a tuple of two None objects. + """ + if not redirect_output: + return None, None + + # Create a directory to be used for process log files. + logs_dir = get_logs_dir_path() + # Create another directory that will be used by some of the RL algorithms. + + # TODO(suquark): This is done by the old code. + # We should be able to control its path later. 
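Taken together, the tempfile_services helpers give every Ray session its own directory under /tmp/ray with logs and sockets subdirectories, and make_inc_temp hands out collision-free names inside them. A hedged usage sketch (the session path in the comment is an example of the session_<date>_<pid> pattern, not a literal value):

    from ray.tempfile_services import (get_logs_dir_path, get_sockets_dir_path,
                                       get_temp_root, make_inc_temp)

    root = get_temp_root()            # e.g. /tmp/ray/session_2018-10-01_12-00-00_4242
    logs = get_logs_dir_path()        # <root>/logs, created on first use
    sockets = get_sockets_dir_path()  # <root>/sockets
    # Repeated requests with the same prefix/suffix get ".1", ".2", ... appended.
    first = make_inc_temp(suffix=".out", prefix="raylet_0", directory_name=logs)
    second = make_inc_temp(suffix=".out", prefix="raylet_0", directory_name=logs)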
+ try_to_create_directory("/tmp/ray") + + log_stdout = make_inc_temp( + suffix=".out", prefix=name, directory_name=logs_dir) + log_stderr = make_inc_temp( + suffix=".err", prefix=name, directory_name=logs_dir) + # Line-buffer the output (mode 1) + log_stdout_file = open(log_stdout, "a", buffering=1) + log_stderr_file = open(log_stderr, "a", buffering=1) + return log_stdout_file, log_stderr_file + + +def new_redis_log_file(redirect_output, shard_number=None): + """Create new logging files for redis""" + if shard_number is None: + redis_stdout_file, redis_stderr_file = new_log_files( + "redis", redirect_output) + else: + redis_stdout_file, redis_stderr_file = new_log_files( + "redis-shard_{}".format(shard_number), redirect_output) + return redis_stdout_file, redis_stderr_file + + +def new_raylet_log_file(local_scheduler_index, redirect_output): + """Create new logging files for raylet.""" + raylet_stdout_file, raylet_stderr_file = new_log_files( + "raylet_{}".format(local_scheduler_index), + redirect_output=redirect_output) + return raylet_stdout_file, raylet_stderr_file + + +def new_webui_log_file(): + """Create new logging files for web ui.""" + ui_stdout_file, ui_stderr_file = new_log_files( + "webui", redirect_output=True) + return ui_stdout_file, ui_stderr_file + + +def new_worker_redirected_log_file(worker_id): + """Create new logging files for workers to redirect its output.""" + worker_stdout_file, worker_stderr_file = (new_log_files( + "worker-" + ray.utils.binary_to_hex(worker_id), True)) + return worker_stdout_file, worker_stderr_file + + +def new_log_monitor_log_file(): + """Create new logging files for the log monitor.""" + log_monitor_stdout_file, log_monitor_stderr_file = new_log_files( + "log_monitor", redirect_output=True) + return log_monitor_stdout_file, log_monitor_stderr_file + + +def new_plasma_store_log_file(local_scheduler_index, redirect_output): + """Create new logging files for the plasma store.""" + plasma_store_stdout_file, plasma_store_stderr_file = new_log_files( + "plasma_store_{}".format(local_scheduler_index), redirect_output) + return plasma_store_stdout_file, plasma_store_stderr_file + + +def new_monitor_log_file(redirect_output): + """Create new logging files for the monitor.""" + monitor_stdout_file, monitor_stderr_file = new_log_files( + "monitor", redirect_output) + return monitor_stdout_file, monitor_stderr_file diff --git a/python/ray/test/cluster_utils.py b/python/ray/test/cluster_utils.py new file mode 100644 index 0000000000000..41dc3b6cdd26a --- /dev/null +++ b/python/ray/test/cluster_utils.py @@ -0,0 +1,226 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import atexit +import logging +import time + +import ray +import ray.services as services + +logger = logging.getLogger(__name__) + + +class Cluster(object): + def __init__(self, + initialize_head=False, + connect=False, + head_node_args=None, + shutdown_at_exit=True): + """Initializes the cluster. + + Args: + initialize_head (bool): Automatically start a Ray cluster + by initializing the head node. Defaults to False. + connect (bool): If `initialize_head=True` and `connect=True`, + ray.init will be called with the redis address of this cluster + passed in. + head_node_args (dict): Arguments to be passed into + `start_ray_head` via `self.add_node`. + shutdown_at_exit (bool): If True, registers an exit hook + for shutting down all started processes. 
+ """ + self.head_node = None + self.worker_nodes = {} + self.redis_address = None + if not initialize_head and connect: + raise RuntimeError("Cannot connect to uninitialized cluster.") + + if initialize_head: + head_node_args = head_node_args or {} + self.add_node(**head_node_args) + if connect: + redis_password = head_node_args.get("redis_password") + output_info = ray.init( + redis_address=self.redis_address, + redis_password=redis_password) + logger.info(output_info) + if shutdown_at_exit: + atexit.register(self.shutdown) + + def add_node(self, **override_kwargs): + """Adds a node to the local Ray Cluster. + + All nodes are by default started with the following settings: + cleanup=True, + resources={"CPU": 1}, + object_store_memory=100 * (2**20) # 100 MB + + Args: + override_kwargs: Keyword arguments used in `start_ray_head` + and `start_ray_node`. Overrides defaults. + + Returns: + Node object of the added Ray node. + """ + node_kwargs = { + "cleanup": True, + "resources": { + "CPU": 1 + }, + "object_store_memory": 100 * (2**20) # 100 MB + } + node_kwargs.update(override_kwargs) + + if self.head_node is None: + address_info = services.start_ray_head( + node_ip_address=services.get_node_ip_address(), + include_webui=False, + **node_kwargs) + self.redis_address = address_info["redis_address"] + # TODO(rliaw): Find a more stable way than modifying global state. + process_dict_copy = services.all_processes.copy() + for key in services.all_processes: + services.all_processes[key] = [] + node = Node(process_dict_copy) + self.head_node = node + else: + address_info = services.start_ray_node( + services.get_node_ip_address(), self.redis_address, + **node_kwargs) + # TODO(rliaw): Find a more stable way than modifying global state. + process_dict_copy = services.all_processes.copy() + for key in services.all_processes: + services.all_processes[key] = [] + node = Node(process_dict_copy) + self.worker_nodes[node] = address_info + logger.info("Starting Node with raylet socket {}".format( + address_info["raylet_socket_names"])) + + return node + + def remove_node(self, node): + """Kills all processes associated with worker node. + + Args: + node (Node): Worker node of which all associated processes + will be removed. + """ + if self.head_node == node: + self.head_node.kill_all_processes() + self.head_node = None + # TODO(rliaw): Do we need to kill all worker processes? + else: + node.kill_all_processes() + self.worker_nodes.pop(node) + + assert not node.any_processes_alive(), ( + "There are zombie processes left over after killing.") + + def wait_for_nodes(self, retries=30): + """Waits for all nodes to be registered with global state. + + By default, waits for 3 seconds. + + Args: + retries (int): Number of times to retry checking client table. + + Returns: + True if successfully registered nodes as expected. + """ + + for i in range(retries): + if not ray.is_initialized() or not self._check_registered_nodes(): + time.sleep(0.1) + else: + return True + return False + + def _check_registered_nodes(self): + registered = len([ + client for client in ray.global_state.client_table() + if client["IsInsertion"] + ]) + expected = len(self.list_all_nodes()) + if registered == expected: + logger.info("All nodes registered as expected.") + else: + logger.info("Currently registering {} but expecting {}".format( + registered, expected)) + return registered == expected + + def list_all_nodes(self): + """Lists all nodes. + + TODO(rliaw): What is the desired behavior if a head node + dies before worker nodes die? 
+ + Returns: + List of all nodes, including the head node. + """ + nodes = list(self.worker_nodes) + if self.head_node: + nodes = [self.head_node] + nodes + return nodes + + def shutdown(self): + """Removes all nodes.""" + + # We create a list here as a copy because `remove_node` + # modifies `self.worker_nodes`. + all_nodes = list(self.worker_nodes) + for node in all_nodes: + self.remove_node(node) + + if self.head_node: + self.remove_node(self.head_node) + else: + logger.warning("No headnode exists!") + + +class Node(object): + """Abstraction for a Ray node.""" + + def __init__(self, process_dict): + # TODO(rliaw): Is there a unique identifier for a node? + self.process_dict = process_dict + + def kill_plasma_store(self): + self.process_dict[services.PROCESS_TYPE_PLASMA_STORE][0].kill() + self.process_dict[services.PROCESS_TYPE_PLASMA_STORE][0].wait() + + def kill_raylet(self): + self.process_dict[services.PROCESS_TYPE_RAYLET][0].kill() + self.process_dict[services.PROCESS_TYPE_RAYLET][0].wait() + + def kill_log_monitor(self): + self.process_dict["log_monitor"][0].kill() + self.process_dict["log_monitor"][0].wait() + + def kill_all_processes(self): + for process_name, process_list in self.process_dict.items(): + logger.info("Killing all {}(s)".format(process_name)) + for process in process_list: + # Kill the process if it is still alive. + if process.poll() is None: + process.kill() + + for process_name, process_list in self.process_dict.items(): + logger.info("Waiting all {}(s)".format(process_name)) + for process in process_list: + process.wait() + + def live_processes(self): + return [(p_name, proc) for p_name, p_list in self.process_dict.items() + for proc in p_list if proc.poll() is None] + + def dead_processes(self): + return [(p_name, proc) for p_name, p_list in self.process_dict.items() + for proc in p_list if proc.poll() is not None] + + def any_processes_alive(self): + return any(self.live_processes()) + + def all_processes_alive(self): + return not any(self.dead_processes()) diff --git a/python/ray/test/test_global_state.py b/python/ray/test/test_global_state.py index 7b12ee0227902..68805a8ec5612 100644 --- a/python/ray/test/test_global_state.py +++ b/python/ray/test/test_global_state.py @@ -2,57 +2,107 @@ from __future__ import division from __future__ import print_function +import json +import pytest +try: + import pytest_timeout +except ImportError: + pytest_timeout = None import time import ray - - -def setup_module(): - if not ray.worker.global_worker.connected: - ray.init(num_cpus=1) - - # Finish initializing Ray. Otherwise available_resources() does not - # reflect resource use of submitted tasks - ray.get(cpu_task.remote(0)) - - -@ray.remote(num_cpus=1) -def cpu_task(seconds): - time.sleep(seconds) - - -class TestAvailableResources(object): - timeout = 10 - - def test_no_tasks(self): - cluster_resources = ray.global_state.cluster_resources() +from ray.test.cluster_utils import Cluster + + +@pytest.fixture +def ray_start(): + # Start the Ray processes. + ray.init(num_cpus=1) + yield None + # The code after the yield will run as teardown code. + ray.shutdown() + + +@pytest.fixture +def cluster_start(): + # Start the Ray processes. + cluster = Cluster( + initialize_head=True, + connect=True, + head_node_args={ + "resources": dict(CPU=1), + "_internal_config": json.dumps({ + "num_heartbeats_timeout": 10 + }) + }) + yield cluster + ray.shutdown() + cluster.shutdown() + + +# TODO(rliaw): The proper way to do this is to have the pytest config setup. 
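The Cluster and Node helpers above wrap the process bookkeeping so tests can grow and shrink a local cluster. A condensed usage sketch based on the API shown in this file:

    import ray
    from ray.test.cluster_utils import Cluster

    cluster = Cluster(initialize_head=True, connect=True,
                      head_node_args={"resources": {"CPU": 1}})
    worker = cluster.add_node(resources={"CPU": 1})
    assert cluster.wait_for_nodes()   # polls until both nodes are registered
    cluster.remove_node(worker)       # kills all processes of that worker node
    assert cluster.wait_for_nodes()
    ray.shutdown()
    cluster.shutdown()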
+@pytest.mark.skipif( + pytest_timeout is None, + reason="Timeout package not installed; skipping test that may hang.") +@pytest.mark.timeout(10) +def test_replenish_resources(ray_start): + cluster_resources = ray.global_state.cluster_resources() + available_resources = ray.global_state.available_resources() + assert cluster_resources == available_resources + + @ray.remote + def cpu_task(): + pass + + ray.get(cpu_task.remote()) + resources_reset = False + + while not resources_reset: available_resources = ray.global_state.available_resources() - assert cluster_resources == available_resources - - def test_replenish_resources(self): - cluster_resources = ray.global_state.cluster_resources() + resources_reset = (cluster_resources == available_resources) + assert resources_reset - ray.get(cpu_task.remote(0)) - start = time.time() - resources_reset = False - while not resources_reset and time.time() - start < self.timeout: - available_resources = ray.global_state.available_resources() - resources_reset = (cluster_resources == available_resources) +@pytest.mark.skipif( + pytest_timeout is None, + reason="Timeout package not installed; skipping test that may hang.") +@pytest.mark.timeout(10) +def test_uses_resources(ray_start): + cluster_resources = ray.global_state.cluster_resources() - assert resources_reset + @ray.remote + def cpu_task(): + time.sleep(1) - def test_uses_resources(self): - cluster_resources = ray.global_state.cluster_resources() - task_id = cpu_task.remote(1) - start = time.time() - resource_used = False + cpu_task.remote() + resource_used = False - while not resource_used and time.time() - start < self.timeout: - available_resources = ray.global_state.available_resources() - resource_used = available_resources[ - "CPU"] == cluster_resources["CPU"] - 1 - - assert resource_used - - ray.get(task_id) # clean up to reset resources + while not resource_used: + available_resources = ray.global_state.available_resources() + resource_used = available_resources[ + "CPU"] == cluster_resources["CPU"] - 1 + + assert resource_used + + +@pytest.mark.skipif( + pytest_timeout is None, + reason="Timeout package not installed; skipping test that may hang.") +@pytest.mark.timeout(20) +def test_add_remove_cluster_resources(cluster_start): + """Tests that Global State API is consistent with actual cluster.""" + cluster = cluster_start + assert ray.global_state.cluster_resources()["CPU"] == 1 + nodes = [] + nodes += [cluster.add_node(resources=dict(CPU=1))] + assert cluster.wait_for_nodes() + assert ray.global_state.cluster_resources()["CPU"] == 2 + + cluster.remove_node(nodes.pop()) + assert cluster.wait_for_nodes() + assert ray.global_state.cluster_resources()["CPU"] == 1 + + for i in range(5): + nodes += [cluster.add_node(resources=dict(CPU=1))] + assert cluster.wait_for_nodes() + assert ray.global_state.cluster_resources()["CPU"] == 6 diff --git a/python/ray/test/test_modin.py b/python/ray/test/test_modin.py new file mode 100644 index 0000000000000..83c11895ec7b8 --- /dev/null +++ b/python/ray/test/test_modin.py @@ -0,0 +1,12 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray # noqa F401 + + +def test_modin_import(): + import modin.pandas as pd + frame_data = [1, 2, 3, 4, 5, 6, 7, 8] + frame = pd.DataFrame(frame_data) + assert frame.sum().squeeze() == sum(frame_data) diff --git a/python/ray/test/test_ray_init.py b/python/ray/test/test_ray_init.py new file mode 100644 index 0000000000000..2b17ce35ef286 --- /dev/null 
+++ b/python/ray/test/test_ray_init.py @@ -0,0 +1,70 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pytest +import redis + +import ray +from ray.test.cluster_utils import Cluster + + +@pytest.fixture +def password(): + random_bytes = os.urandom(128) + if hasattr(random_bytes, "hex"): + return random_bytes.hex() # Python 3 + return random_bytes.encode("hex") # Python 2 + + +@pytest.fixture +def shutdown_only(): + yield None + # The code after the yield will run as teardown code. + ray.shutdown() + + +class TestRedisPassword(object): + @pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="New GCS API doesn't support Redis authentication yet.") + def test_redis_password(self, password, shutdown_only): + @ray.remote + def f(): + return 1 + + info = ray.init(redis_password=password) + redis_address = info["redis_address"] + redis_ip, redis_port = redis_address.split(":") + + # Check that we can run a task + object_id = f.remote() + ray.get(object_id) + + # Check that Redis connections require a password + redis_client = redis.StrictRedis( + host=redis_ip, port=redis_port, password=None) + with pytest.raises(redis.ResponseError): + redis_client.ping() + + # Check that we can connect to Redis using the provided password + redis_client = redis.StrictRedis( + host=redis_ip, port=redis_port, password=password) + assert redis_client.ping() + + @pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="New GCS API doesn't support Redis authentication yet.") + def test_redis_password_cluster(self, password, shutdown_only): + @ray.remote + def f(): + return 1 + + node_args = {"redis_password": password} + cluster = Cluster( + initialize_head=True, connect=True, head_node_args=node_args) + cluster.add_node(**node_args) + + object_id = f.remote() + ray.get(object_id) diff --git a/python/ray/test/test_utils.py b/python/ray/test/test_utils.py index a29daa5a073ec..a3614650e97ba 100644 --- a/python/ray/test/test_utils.py +++ b/python/ray/test/test_utils.py @@ -6,6 +6,7 @@ import os import redis import subprocess +import sys import tempfile import time @@ -34,22 +35,11 @@ def _wait_for_nodes_to_join(num_nodes, timeout=20): client_table = ray.global_state.client_table() num_ready_nodes = len(client_table) if num_ready_nodes == num_nodes: - ready = True # Check that for each node, a local scheduler and a plasma manager # are present. - if ray.global_state.use_raylet: - # In raylet mode, this is a list of map. - # The GCS info will appear as a whole instead of part by part. - return - else: - for ip_address, clients in client_table.items(): - client_types = [client["ClientType"] for client in clients] - if "local_scheduler" not in client_types: - ready = False - if "plasma_manager" not in client_types: - ready = False - if ready: - return + # In raylet mode, this is a list of map. + # The GCS info will appear as a whole instead of part by part. + return if num_ready_nodes > num_nodes: # Too many nodes have joined. Something must be wrong. raise Exception("{} nodes have joined the cluster, but we were " @@ -147,3 +137,40 @@ def run_and_get_output(command): with open(tmp.name, 'r') as f: result = f.readlines() return "\n".join(result) + + +def run_string_as_driver(driver_script): + """Run a driver as a separate process. + + Args: + driver_script: A string to run as a Python script. + + Returns: + The script's output. 
+ """ + # Save the driver script as a file so we can call it using subprocess. + with tempfile.NamedTemporaryFile() as f: + f.write(driver_script.encode("ascii")) + f.flush() + out = ray.utils.decode( + subprocess.check_output([sys.executable, f.name])) + return out + + +def run_string_as_driver_nonblocking(driver_script): + """Start a driver as a separate process and return immediately. + + Args: + driver_script: A string to run as a Python script. + + Returns: + A handle to the driver process. + """ + # Save the driver script as a file so we can call it using subprocess. We + # do not delete this file because if we do then it may get removed before + # the Python process tries to run it. + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(driver_script.encode("ascii")) + f.flush() + return subprocess.Popen( + [sys.executable, f.name], stdout=subprocess.PIPE) diff --git a/python/ray/tune/README.rst b/python/ray/tune/README.rst index 2d7533f56a0f6..5635ab3ff8856 100644 --- a/python/ray/tune/README.rst +++ b/python/ray/tune/README.rst @@ -6,6 +6,14 @@ Tune is a scalable framework for hyperparameter search with a focus on deep lear User documentation can be `found here `__. +Tutorial +-------- + +To get started with Tune, try going through `our tutorial of using Tune with Keras `__. + +(Experimental): You can try out `the above tutorial on a free hosted server via Binder `__. + + Citing Tune ----------- diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index 83d4f4fdece37..1e341b26526ea 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -7,9 +7,16 @@ from ray.tune.experiment import Experiment from ray.tune.registry import register_env, register_trainable from ray.tune.trainable import Trainable -from ray.tune.suggest import grid_search, function +from ray.tune.suggest import grid_search, function, sample_from __all__ = [ - "Trainable", "TuneError", "grid_search", "register_env", - "register_trainable", "run_experiments", "Experiment", "function" + "Trainable", + "TuneError", + "grid_search", + "register_env", + "register_trainable", + "run_experiments", + "Experiment", + "function", + "sample_from", ] diff --git a/python/ray/tune/examples/README.rst b/python/ray/tune/examples/README.rst index 3d35497c8841a..a762a057021c3 100644 --- a/python/ray/tune/examples/README.rst +++ b/python/ray/tune/examples/README.rst @@ -1,4 +1,60 @@ Tune Examples ============= -Code examples for various schedulers and Tune features. +.. Keep this in sync with ray/doc/tune-examples.rst + +In our repository, we provide a variety of examples for the various use cases and features of Tune. + +If any example is broken, or if you'd like to add an example to this page, feel free to raise an issue on our Github repository. + + +General Examples +---------------- + +- `async_hyperband_example `__: + Example of using a Trainable class with AsyncHyperBandScheduler. +- `hyperband_example `__: + Example of using a Trainable class with HyperBandScheduler. Also uses the Experiment class API for specifying the experiment configuration. +- `hyperopt_example `__: + Optimizes a basic function using the function-based API and the HyperOptSearch (SearchAlgorithm wrapper for HyperOpt TPE). + Also uses the AsyncHyperBandScheduler. +- `pbt_example `__: + Example of using a Trainable class with PopulationBasedTraining scheduler. +- `pbt_ppo_example `__: + Example of optimizing a distributed RLlib algorithm (PPO) with the PopulationBasedTraining scheduler. 
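Most of the examples listed in this README share the same skeleton: define a trainable (class- or function-based), wrap randomly drawn config values in `sample_from`, and hand a scheduler to `run_experiments`. The following condensed sketch shows that skeleton; the trainable, experiment name, and numbers are illustrative rather than taken from any single example.

import random

import ray
from ray.tune import run_experiments, sample_from
from ray.tune.schedulers import AsyncHyperBandScheduler


def my_trainable(config, reporter):
    # Toy objective whose reported "accuracy" depends on the sampled width.
    for step in range(100):
        reporter(timesteps_total=step, mean_accuracy=config["width"] / 100.0)


ray.init()
run_experiments(
    {
        "tune_sketch": {
            "run": my_trainable,
            "num_samples": 4,
            "stop": {"training_iteration": 99},
            "config": {
                "width": sample_from(
                    lambda spec: 10 + int(90 * random.random())),
            },
        }
    },
    scheduler=AsyncHyperBandScheduler(reward_attr="mean_accuracy"))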
+ + +Keras Examples +-------------- + +- `tune_mnist_keras `__: + Converts the Keras MNIST example to use Tune with the function-based API and a Keras callback. Also shows how to easily convert something relying on argparse to use Tune. + + +PyTorch Examples +---------------- + +- `mnist_pytorch `__: + Converts the PyTorch MNIST example to use Tune with the function-based API. Also shows how to easily convert something relying on argparse to use Tune. +- `mnist_pytorch_trainable `__: + Converts the PyTorch MNIST example to use Tune with Trainable API. Also uses the HyperBandScheduler and checkpoints the model at the end. + + +TensorFlow Examples +------------------- + +- `tune_mnist_ray `__: + A basic example of tuning a TensorFlow model on MNIST using the Trainable class. +- `tune_mnist_ray_hyperband `__: + A basic example of tuning a TensorFlow model on MNIST using the Trainable class and the HyperBand scheduler. +- `tune_mnist_async_hyperband `__: + Example of tuning a TensorFlow model on MNIST using AsyncHyperBand. + + +Contributed Examples +-------------------- + +- `pbt_tune_cifar10_with_keras `__: + A contributed example of tuning a Keras model on CIFAR10 with the PopulationBasedTraining scheduler. +- `genetic_example `__: + Optimizing the michalewicz function using the contributed GeneticSearch search algorithm with AsyncHyperBandScheduler. diff --git a/python/ray/rllib/examples/legacy_multiagent/__init__.py b/python/ray/tune/examples/__init__.py similarity index 100% rename from python/ray/rllib/examples/legacy_multiagent/__init__.py rename to python/ray/tune/examples/__init__.py diff --git a/python/ray/tune/examples/async_hyperband_example.py b/python/ray/tune/examples/async_hyperband_example.py index 2c368b4e3d05e..871e8c1718ea0 100644 --- a/python/ray/tune/examples/async_hyperband_example.py +++ b/python/ray/tune/examples/async_hyperband_example.py @@ -12,7 +12,7 @@ import numpy as np import ray -from ray.tune import Trainable, run_experiments +from ray.tune import Trainable, run_experiments, sample_from from ray.tune.schedulers import AsyncHyperBandScheduler @@ -23,7 +23,7 @@ class MyTrainableClass(Trainable): maximum reward value reached. """ - def _setup(self): + def _setup(self, config): self.timestep = 0 def _train(self): @@ -76,8 +76,10 @@ def _restore(self, checkpoint_path): "gpu": 0 }, "config": { - "width": lambda spec: 10 + int(90 * random.random()), - "height": lambda spec: int(100 * random.random()), + "width": sample_from( + lambda spec: 10 + int(90 * random.random())), + "height": sample_from( + lambda spec: int(100 * random.random())), }, } }, diff --git a/python/ray/tune/examples/hyperband_example.py b/python/ray/tune/examples/hyperband_example.py index 94f603e8206cf..d403a0e0f8af1 100755 --- a/python/ray/tune/examples/hyperband_example.py +++ b/python/ray/tune/examples/hyperband_example.py @@ -12,7 +12,7 @@ import numpy as np import ray -from ray.tune import Trainable, run_experiments, Experiment +from ray.tune import Trainable, run_experiments, Experiment, sample_from from ray.tune.schedulers import HyperBandScheduler @@ -23,7 +23,7 @@ class MyTrainableClass(Trainable): maximum reward value reached. 
""" - def _setup(self): + def _setup(self, config): self.timestep = 0 def _train(self): @@ -67,8 +67,8 @@ def _restore(self, checkpoint_path): num_samples=20, stop={"training_iteration": 1 if args.smoke_test else 99999}, config={ - "width": lambda spec: 10 + int(90 * random.random()), - "height": lambda spec: int(100 * random.random()) + "width": sample_from(lambda spec: 10 + int(90 * random.random())), + "height": sample_from(lambda spec: int(100 * random.random())) }) run_experiments(exp, scheduler=hyperband) diff --git a/python/ray/tune/examples/hyperopt_example.py b/python/ray/tune/examples/hyperopt_example.py index 6d61b1321e2ed..2898bf26d8539 100644 --- a/python/ray/tune/examples/hyperopt_example.py +++ b/python/ray/tune/examples/hyperopt_example.py @@ -48,7 +48,7 @@ def easy_objective(config, reporter): "run": "exp", "num_samples": 10 if args.smoke_test else 1000, "stop": { - "training_iteration": 100 + "timesteps_total": 100 }, } } diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py index bfd319bff3c9d..bec73f3d51249 100644 --- a/python/ray/tune/examples/mnist_pytorch.py +++ b/python/ray/tune/examples/mnist_pytorch.py @@ -182,8 +182,10 @@ def test(): "run": "train_mnist", "num_samples": 1 if args.smoke_test else 10, "config": { - "lr": lambda spec: np.random.uniform(0.001, 0.1), - "momentum": lambda spec: np.random.uniform(0.1, 0.9), + "lr": tune.sample_from( + lambda spec: np.random.uniform(0.001, 0.1)), + "momentum": tune.sample_from( + lambda spec: np.random.uniform(0.1, 0.9)), } } }, diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py index 0d23c0cc21304..6005cd79c14fb 100644 --- a/python/ray/tune/examples/mnist_pytorch_trainable.py +++ b/python/ray/tune/examples/mnist_pytorch_trainable.py @@ -80,9 +80,9 @@ def forward(self, x): class TrainMNIST(Trainable): - def _setup(self): - args = self.config.pop("args") - vars(args).update(self.config) + def _setup(self, config): + args = config.pop("args") + vars(args).update(config) args.cuda = not args.no_cuda and torch.cuda.is_available() torch.manual_seed(args.seed) @@ -159,12 +159,13 @@ def _train(self): self._train_iteration() return self._test() - def _save(self, path): - torch.save(self.model.state_dict(), os.path.join(path, "model.pth")) - return path + def _save(self, checkpoint_dir): + checkpoint_path = os.path.join(checkpoint_dir, "model.pth") + torch.save(self.model.state_dict(), checkpoint_path) + return checkpoint_path - def _restore(self, path): - self.model.load_state_dict(os.path.join(path, "model.pth")) + def _restore(self, checkpoint_path): + self.model.load_state_dict(checkpoint_path) if __name__ == '__main__': @@ -194,8 +195,10 @@ def _restore(self, path): "checkpoint_at_end": True, "config": { "args": args, - "lr": lambda spec: np.random.uniform(0.001, 0.1), - "momentum": lambda spec: np.random.uniform(0.1, 0.9), + "lr": tune.sample_from( + lambda spec: np.random.uniform(0.001, 0.1)), + "momentum": tune.sample_from( + lambda spec: np.random.uniform(0.1, 0.9)), } } }, diff --git a/python/ray/tune/examples/pbt_example.py b/python/ray/tune/examples/pbt_example.py index c958d2512e83a..3433e82f94eed 100755 --- a/python/ray/tune/examples/pbt_example.py +++ b/python/ray/tune/examples/pbt_example.py @@ -18,7 +18,7 @@ class MyTrainableClass(Trainable): """Fake agent whose learning rate is determined by dummy factors.""" - def _setup(self): + def _setup(self, config): self.timestep = 0 self.current_value = 0.0 
diff --git a/python/ray/tune/examples/pbt_ppo_example.py b/python/ray/tune/examples/pbt_ppo_example.py index efd7ee4a89580..a81d4109f62c1 100755 --- a/python/ray/tune/examples/pbt_ppo_example.py +++ b/python/ray/tune/examples/pbt_ppo_example.py @@ -13,7 +13,7 @@ import random import ray -from ray.tune import run_experiments +from ray.tune import run_experiments, sample_from from ray.tune.schedulers import PopulationBasedTraining if __name__ == "__main__": @@ -63,12 +63,12 @@ def explore(config): "clip_param": 0.2, "lr": 1e-4, # These params start off randomly drawn from a set. - "num_sgd_iter": - lambda spec: random.choice([10, 20, 30]), - "sgd_minibatch_size": - lambda spec: random.choice([128, 512, 2048]), - "train_batch_size": - lambda spec: random.choice([10000, 20000, 40000]) + "num_sgd_iter": sample_from( + lambda spec: random.choice([10, 20, 30])), + "sgd_minibatch_size": sample_from( + lambda spec: random.choice([128, 512, 2048])), + "train_batch_size": sample_from( + lambda spec: random.choice([10000, 20000, 40000])) }, }, }, diff --git a/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py b/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py index 28575f5466824..692c967cf2946 100755 --- a/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py +++ b/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py @@ -23,7 +23,7 @@ from tensorflow.python.keras.preprocessing.image import ImageDataGenerator import ray -from ray.tune import grid_search, run_experiments +from ray.tune import grid_search, run_experiments, sample_from from ray.tune import Trainable from ray.tune.schedulers import PopulationBasedTraining @@ -105,7 +105,7 @@ def _build_model(self, input_shape): model = Model(inputs=x, outputs=y, name="model1") return model - def _setup(self): + def _setup(self, config): self.train_data, self.test_data = self._read_data() x_train = self.train_data[0] model = self._build_model(x_train.shape[1:]) @@ -193,7 +193,7 @@ def _stop(self): "epochs": 1, "batch_size": 64, "lr": grid_search([10**-4, 10**-5]), - "decay": lambda spec: spec.config.lr / 100.0, + "decay": sample_from(lambda spec: spec.config.lr / 100.0), "dropout": grid_search([0.25, 0.5]), }, "num_samples": 4, diff --git a/python/ray/tune/examples/tune_mnist_keras.py b/python/ray/tune/examples/tune_mnist_keras.py index 32cfb371efc46..cbe9f626df6f4 100644 --- a/python/ray/tune/examples/tune_mnist_keras.py +++ b/python/ray/tune/examples/tune_mnist_keras.py @@ -105,6 +105,8 @@ def create_parser(): parser = argparse.ArgumentParser(description='Keras MNIST Example') parser.add_argument( "--smoke-test", action="store_true", help="Finish quickly for testing") + parser.add_argument( + "--use-gpu", action="store_true", help="Use GPU in training.") parser.add_argument( '--jobs', type=int, @@ -113,8 +115,8 @@ def create_parser(): parser.add_argument( '--threads', type=int, - default=None, - help='threads used in operations (default: all)') + default=2, + help='threads used in operations (default: 2)') parser.add_argument( '--steps', type=float, @@ -185,11 +187,19 @@ def create_parser(): }, "run": "train_mnist", "num_samples": 1 if args.smoke_test else 10, + "trial_resources": { + "cpu": args.threads, + "gpu": 0.5 if args.use_gpu else 0 + }, "config": { - "lr": lambda spec: np.random.uniform(0.001, 0.1), - "momentum": lambda spec: np.random.uniform(0.1, 0.9), - "hidden": lambda spec: np.random.randint(32, 512), - "dropout1": lambda spec: np.random.uniform(0.2, 0.8), + "lr": tune.sample_from( + lambda spec: np.random.uniform(0.001, 
0.1)), + "momentum": tune.sample_from( + lambda spec: np.random.uniform(0.1, 0.9)), + "hidden": tune.sample_from( + lambda spec: np.random.randint(32, 512)), + "dropout1": tune.sample_from( + lambda spec: np.random.uniform(0.2, 0.8)), } } }, diff --git a/python/ray/tune/examples/tune_mnist_ray.py b/python/ray/tune/examples/tune_mnist_ray.py index e806a1a68ac8e..e56ebd10f5ebc 100755 --- a/python/ray/tune/examples/tune_mnist_ray.py +++ b/python/ray/tune/examples/tune_mnist_ray.py @@ -42,7 +42,7 @@ FLAGS = None status_reporter = None # used to report training status back to Ray -activation_fn = None # e.g. tf.nn.relu +activation_fn = tf.nn.relu # e.g. tf.nn.relu def deepnn(x): diff --git a/python/ray/tune/examples/tune_mnist_ray_hyperband.py b/python/ray/tune/examples/tune_mnist_ray_hyperband.py index 29939ff243085..bce19deca6859 100755 --- a/python/ray/tune/examples/tune_mnist_ray_hyperband.py +++ b/python/ray/tune/examples/tune_mnist_ray_hyperband.py @@ -31,7 +31,7 @@ import ray from ray.tune import grid_search, run_experiments, register_trainable, \ - Trainable + Trainable, sample_from from ray.tune.schedulers import HyperBandScheduler from tensorflow.examples.tutorials.mnist import input_data @@ -128,7 +128,7 @@ def bias_variable(shape): class TrainMNIST(Trainable): """Example MNIST trainable.""" - def _setup(self): + def _setup(self, config): global activation_fn self.timestep = 0 @@ -148,7 +148,7 @@ def _setup(self): self.x = tf.placeholder(tf.float32, [None, 784]) self.y_ = tf.placeholder(tf.float32, [None, 10]) - activation_fn = getattr(tf.nn, self.config['activation']) + activation_fn = getattr(tf.nn, config['activation']) # Build the graph for the deep net y_conv, self.keep_prob = setupCNN(self.x) @@ -160,7 +160,7 @@ def _setup(self): with tf.name_scope('adam_optimizer'): train_step = tf.train.AdamOptimizer( - self.config['learning_rate']).minimize(cross_entropy) + config['learning_rate']).minimize(cross_entropy) self.train_step = train_step @@ -221,7 +221,8 @@ def _restore(self, path): 'time_total_s': 600, }, 'config': { - 'learning_rate': lambda spec: 10**np.random.uniform(-5, -3), + 'learning_rate': sample_from( + lambda spec: 10**np.random.uniform(-5, -3)), 'activation': grid_search(['relu', 'elu', 'tanh']), }, "num_samples": 10, diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 390a66193d672..3a4ddc9c7aab8 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -85,7 +85,7 @@ def __init__(self, repeat=1, num_samples=1, local_dir=None, - upload_dir="", + upload_dir=None, checkpoint_freq=0, checkpoint_at_end=False, max_failures=3, @@ -97,7 +97,7 @@ def __init__(self, "trial_resources": trial_resources, "num_samples": num_samples, "local_dir": local_dir or DEFAULT_RESULTS_DIR, - "upload_dir": upload_dir, + "upload_dir": upload_dir or "", # argparse converts None to "null" "checkpoint_freq": checkpoint_freq, "checkpoint_at_end": checkpoint_at_end, "max_failures": max_failures, diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index d1704b6aa94fe..47593f2213bcd 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -14,11 +14,12 @@ class StatusReporter(object): - """Object passed into your main() that you can report status through. + """Object passed into your function that you can report status through. 
Example: - >>> reporter = StatusReporter() - >>> reporter(timesteps_total=1) + >>> def trainable_function(config, reporter): + >>> assert isinstance(reporter, StatusReporter) + >>> reporter(timesteps_total=1) """ def __init__(self): @@ -33,6 +34,9 @@ def __call__(self, **kwargs): Args: kwargs: Latest training result status. + + Example: + >>> reporter(mean_accuracy=1, training_iteration=4) """ with self._lock: @@ -90,10 +94,10 @@ class FunctionRunner(Trainable): _name = "func" _default_config = DEFAULT_CONFIG - def _setup(self): + def _setup(self, config): entrypoint = self._trainable_func() self._status_reporter = StatusReporter() - scrubbed_config = self.config.copy() + scrubbed_config = config.copy() for k in self._default_config: if k in scrubbed_config: del scrubbed_config[k] diff --git a/python/ray/tune/log_sync.py b/python/ray/tune/log_sync.py index 2e18a86582085..109c11a01707f 100644 --- a/python/ray/tune/log_sync.py +++ b/python/ray/tune/log_sync.py @@ -107,11 +107,13 @@ def sync_now(self, force=False): if not distutils.spawn.find_executable("rsync"): logger.error("Log sync requires rsync to be installed.") return + source = '{}@{}:{}/'.format(ssh_user, self.worker_ip, + self.local_dir) + target = '{}/'.format(self.local_dir) worker_to_local_sync_cmd = (( - """rsync -avz -e "ssh -i {} -o ConnectTimeout=120s """ - """-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format( - quote(ssh_key), ssh_user, self.worker_ip, - quote(self.local_dir), quote(self.local_dir))) + """rsync -savz -e "ssh -i {} -o ConnectTimeout=120s """ + """-o StrictHostKeyChecking=no" {} {}""").format( + quote(ssh_key), quote(source), quote(target))) if self.remote_dir: if self.remote_dir.startswith(S3_PREFIX): diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index f73aa4a1ef8ce..183ba6490b979 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -9,6 +9,7 @@ import os import yaml +import ray.cloudpickle as cloudpickle from ray.tune.log_sync import get_syncer from ray.tune.result import NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, \ TIMESTEPS_TOTAL @@ -97,7 +98,15 @@ class _JsonLogger(Logger): def _init(self): config_out = os.path.join(self.logdir, "params.json") with open(config_out, "w") as f: - json.dump(self.config, f, sort_keys=True, cls=_SafeFallbackEncoder) + json.dump( + self.config, + f, + indent=2, + sort_keys=True, + cls=_SafeFallbackEncoder) + config_pkl = os.path.join(self.logdir, "params.pkl") + with open(config_pkl, "wb") as f: + cloudpickle.dump(self.config, f) local_file = os.path.join(self.logdir, "result.json") self.local_out = open(local_file, "w") diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 86f09cda34d8e..6b107b17c82f9 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -29,6 +29,8 @@ def __init__(self, queue_trials=False): self._avail_resources = Resources(cpu=0, gpu=0) self._committed_resources = Resources(cpu=0, gpu=0) self._resources_initialized = False + if ray.is_initialized(): + self._update_avail_resources() def _setup_runner(self, trial): cls = ray.remote( @@ -108,19 +110,27 @@ def _stop_trial(self, trial, error=False, error_msg=None, if stop_logger: trial.close_logger() - def start_trial(self, trial, checkpoint_obj=None): - """Starts the trial.""" + def start_trial(self, trial, checkpoint=None): + """Starts the trial. + + Will not return resources if trial repeatedly fails on start. + + Args: + trial (Trial): Trial to be started. 
+ checkpoint (Checkpoint): A Python object or path storing the state + of trial. + """ self._commit_resources(trial.resources) try: - self._start_trial(trial, checkpoint_obj) + self._start_trial(trial, checkpoint) except Exception: logger.exception("Error stopping runner - retrying...") error_msg = traceback.format_exc() time.sleep(2) self._stop_trial(trial, error=True, error_msg=error_msg) try: - self._start_trial(trial) + self._start_trial(trial, checkpoint) except Exception: logger.exception("Error starting runner, aborting!") error_msg = traceback.format_exc() @@ -138,6 +148,7 @@ def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True): self._stop_trial( trial, error=error, error_msg=error_msg, stop_logger=stop_logger) if prior_status == Trial.RUNNING: + logger.debug("Returning resources for this trial.") self._return_resources(trial.resources) out = self._find_item(self._running, trial) for result_id in out: @@ -213,19 +224,9 @@ def _return_resources(self, resources): assert self._committed_resources.gpu >= 0 def _update_avail_resources(self): - clients = ray.global_state.client_table() - if ray.worker.global_worker.use_raylet: - # TODO(rliaw): Remove once raylet flag is swapped - num_cpus = sum(cl['Resources']['CPU'] for cl in clients) - num_gpus = sum(cl['Resources'].get('GPU', 0) for cl in clients) - else: - local_schedulers = [ - entry for client in clients.values() for entry in client - if (entry['ClientType'] == 'local_scheduler' - and not entry['Deleted']) - ] - num_cpus = sum(ls['CPU'] for ls in local_schedulers) - num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers) + resources = ray.global_state.cluster_resources() + num_cpus = resources["CPU"] + num_gpus = resources["GPU"] self._avail_resources = Resources(int(num_cpus), int(num_gpus)) self._resources_initialized = True @@ -267,7 +268,16 @@ def debug_string(self): self._committed_resources.cpu, self._avail_resources.cpu, self._committed_resources.gpu, self._avail_resources.gpu) else: - return "" + return "Resources requested: ?" + + def resource_string(self): + """Returns a string describing the total resources available.""" + + if self._resources_initialized: + return "{} CPUs, {} GPUs".format(self._avail_resources.cpu, + self._avail_resources.gpu) + else: + return "? CPUs, ? GPUs" def on_step_begin(self): """Before step() called, update the available resources.""" @@ -277,6 +287,7 @@ def on_step_begin(self): def save(self, trial, storage=Checkpoint.DISK): """Saves the trial's state to a checkpoint.""" trial._checkpoint.storage = storage + trial._checkpoint.last_result = trial.last_result if storage == Checkpoint.MEMORY: trial._checkpoint.value = trial.runner.save_to_object.remote() else: @@ -300,6 +311,8 @@ def restore(self, trial, checkpoint=None): ray.get(trial.runner.restore_from_object.remote(value)) else: ray.get(trial.runner.restore.remote(value)) + trial.last_result = checkpoint.last_result + return True except Exception: logger.exception("Error restoring runner.") diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index ec307eaed8fb8..0d5aeb0d0618c 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -4,6 +4,8 @@ import os +# yapf: disable +# __sphinx_doc_begin__ # (Optional/Auto-filled) training is terminated. Filled only if not provided. DONE = "done" @@ -16,16 +18,16 @@ # (Auto-filled) The pid of the training process. PID = "pid" -# Number of timesteps in this iteration. +# Number of episodes in this iteration. 
EPISODES_THIS_ITER = "episodes_this_iter" -# (Optional/Auto-filled) Accumulated time in seconds for this experiment. +# (Optional/Auto-filled) Accumulated number of episodes for this experiment. EPISODES_TOTAL = "episodes_total" # Number of timesteps in this iteration. TIMESTEPS_THIS_ITER = "timesteps_this_iter" -# (Optional/Auto-filled) Accumulated time in seconds for this experiment. +# (Auto-filled) Accumulated number of timesteps for this entire experiment. TIMESTEPS_TOTAL = "timesteps_total" # (Auto-filled) Time in seconds this iteration took to run. @@ -35,11 +37,14 @@ # (Auto-filled) Accumulated time in seconds for this entire experiment. TIME_TOTAL_S = "time_total_s" -# (Auto-filled) The index of thistraining iteration. +# (Auto-filled) The index of this training iteration. TRAINING_ITERATION = "training_iteration" +# __sphinx_doc_end__ +# yapf: enable # Where Tune writes result files by default -DEFAULT_RESULTS_DIR = os.path.expanduser("~/ray_results") +DEFAULT_RESULTS_DIR = (os.environ.get("TUNE_RESULT_DIR") + or os.path.expanduser("~/ray_results")) # Meta file about status under each experiment directory, can be # parsed by automlboard if exists. diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index 7e2f8f27e278f..71c69b3063a2b 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -50,7 +50,10 @@ class HyperBandScheduler(FIFOScheduler): For example, to limit trials to 10 minutes and early stop based on the `episode_mean_reward` attr, construct: - ``HyperBand('time_total_s', 'episode_reward_mean', 600)`` + ``HyperBand('time_total_s', 'episode_reward_mean', max_t=600)`` + + Note that Tune's stopping criteria will be applied in conjunction with + HyperBand's early stopping mechanisms. 
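A short, illustrative sketch of that interaction, using the `run_experiments` API seen elsewhere in this change (the trainable and numbers are made up): a trial ends at whichever bound fires first, HyperBand's `max_t` budget or the experiment's own `stop` criterion.

import ray
from ray.tune import run_experiments
from ray.tune.schedulers import HyperBandScheduler


def dummy_trainable(config, reporter):
    # Reports an increasing reward until HyperBand or the stop criterion
    # ends the trial.
    for i in range(10**6):
        reporter(timesteps_total=i, episode_reward_mean=i)


ray.init()
run_experiments(
    {
        "hyperband_sketch": {
            "run": dummy_trainable,
            "num_samples": 8,
            # Applied in conjunction with HyperBand's early stopping.
            "stop": {"training_iteration": 100},
        }
    },
    scheduler=HyperBandScheduler(
        time_attr="time_total_s",
        reward_attr="episode_reward_mean",
        max_t=600))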
See also: https://people.eecs.berkeley.edu/~kjamieson/hyperband.html diff --git a/python/ray/tune/suggest/__init__.py b/python/ray/tune/suggest/__init__.py index f0146ca5e7992..9f4a5e6a7ad06 100644 --- a/python/ray/tune/suggest/__init__.py +++ b/python/ray/tune/suggest/__init__.py @@ -2,9 +2,15 @@ from ray.tune.suggest.basic_variant import BasicVariantGenerator from ray.tune.suggest.suggestion import SuggestionAlgorithm from ray.tune.suggest.hyperopt import HyperOptSearch -from ray.tune.suggest.variant_generator import grid_search, function +from ray.tune.suggest.variant_generator import grid_search, function, \ + sample_from __all__ = [ - "SearchAlgorithm", "BasicVariantGenerator", "HyperOptSearch", - "SuggestionAlgorithm", "grid_search", "function" + "SearchAlgorithm", + "BasicVariantGenerator", + "HyperOptSearch", + "SuggestionAlgorithm", + "grid_search", + "function", + "sample_from", ] diff --git a/python/ray/tune/suggest/hyperopt.py b/python/ray/tune/suggest/hyperopt.py index 45fe9753e0ea0..2c1c1317616d3 100644 --- a/python/ray/tune/suggest/hyperopt.py +++ b/python/ray/tune/suggest/hyperopt.py @@ -4,9 +4,13 @@ import numpy as np import copy +import logging + try: + hyperopt_logger = logging.getLogger("hyperopt") + hyperopt_logger.setLevel(logging.WARNING) import hyperopt as hpo -except Exception as e: +except Exception: hpo = None from ray.tune.error import TuneError @@ -47,7 +51,6 @@ class HyperOptSearch(SuggestionAlgorithm): >>> } >>> algo = HyperOptSearch( >>> space, max_concurrent=4, reward_attr="neg_mean_loss") - >>> algo.add_configurations(config) """ def __init__(self, diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index 98b830754093e..d57b586c0e5d2 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -3,12 +3,15 @@ from __future__ import print_function import copy +import logging import numpy import random import types from ray.tune import TuneError +logger = logging.getLogger(__name__) + def generate_variants(unresolved_spec): """Generates variants from a spec (dict) with unresolved values. @@ -55,8 +58,29 @@ def grid_search(values): return {"grid_search": values} +class sample_from(object): + """Specify that tune should sample configuration values from this function. + + The use of function arguments in tune configs must be disambiguated by + either wrapped the function in tune.eval() or tune.function(). + + Arguments: + func: An callable function to draw a sample from. + """ + + def __init__(self, func): + self.func = func + + class function(object): - """Wraps `func` to make sure it is not expanded during resolution.""" + """Wraps `func` to make sure it is not expanded during resolution. + + The use of function arguments in tune configs must be disambiguated by + either wrapped the function in tune.eval() or tune.function(). + + Arguments: + func: A function literal. + """ def __init__(self, func): self.func = func @@ -155,6 +179,11 @@ def _resolve_lambda_vars(spec, lambda_vars): value = fn(_UnresolvedAccessGuard(spec)) except RecursiveDependencyError as e: error = e + except Exception: + raise ValueError( + "Failed to evaluate expression: {}: {}".format(path, fn) + + ". 
If you meant to pass this as a function literal, use " + "tune.function() to escape it.") else: _assign_value(spec, path, value) resolved[path] = value @@ -198,8 +227,17 @@ def _is_resolved(v): def _try_resolve(v): if isinstance(v, types.FunctionType): - # Lambda function + logger.warn( + "Deprecation warning: Function values are ambiguous in Tune " + "configuations. Either wrap the function with " + "`tune.function(func)` to specify a function literal, or " + "`tune.sample_from(func)` to tell Tune to " + "sample values from the function during variant generation: " + "{}".format(v)) return False, v + elif isinstance(v, sample_from): + # Function to sample from + return False, v.func elif isinstance(v, dict) and len(v) == 1 and "eval" in v: # Lambda function in eval syntax return False, lambda spec: eval( diff --git a/python/ray/tune/test/cluster_tests.py b/python/ray/tune/test/cluster_tests.py new file mode 100644 index 0000000000000..59f12181b8ff9 --- /dev/null +++ b/python/ray/tune/test/cluster_tests.py @@ -0,0 +1,242 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import pytest +try: + import pytest_timeout +except ImportError: + pytest_timeout = None + +import ray +from ray.rllib import _register_all +from ray.test.cluster_utils import Cluster +from ray.tune.error import TuneError +from ray.tune.trial import Trial +from ray.tune.trial_runner import TrialRunner +from ray.tune.suggest import BasicVariantGenerator + + +def _start_new_cluster(): + cluster = Cluster( + initialize_head=True, + connect=True, + head_node_args={ + "resources": dict(CPU=1), + "_internal_config": json.dumps({ + "num_heartbeats_timeout": 10 + }) + }) + # Pytest doesn't play nicely with imports + _register_all() + return cluster + + +@pytest.fixture +def start_connected_cluster(): + # Start the Ray processes. + cluster = _start_new_cluster() + yield cluster + # The code after the yield will run as teardown code. + ray.shutdown() + cluster.shutdown() + + +@pytest.fixture +def start_connected_emptyhead_cluster(): + """Starts head with no resources.""" + + cluster = Cluster( + initialize_head=True, + connect=True, + head_node_args={ + "resources": dict(CPU=0), + "_internal_config": json.dumps({ + "num_heartbeats_timeout": 10 + }) + }) + # Pytest doesn't play nicely with imports + _register_all() + yield cluster + # The code after the yield will run as teardown code. 
+ ray.shutdown()
+ cluster.shutdown()
+
+
+def test_counting_resources(start_connected_cluster):
+ """Tests that Tune accounting is consistent with actual cluster."""
+
+ cluster = start_connected_cluster
+ nodes = []
+ assert ray.global_state.cluster_resources()["CPU"] == 1
+ runner = TrialRunner(BasicVariantGenerator())
+ kwargs = {"stopping_criterion": {"training_iteration": 10}}
+
+ trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
+ for t in trials:
+ runner.add_trial(t)
+
+ runner.step() # run 1
+ nodes += [cluster.add_node(resources=dict(CPU=1))]
+ assert cluster.wait_for_nodes()
+ assert ray.global_state.cluster_resources()["CPU"] == 2
+ cluster.remove_node(nodes.pop())
+ assert cluster.wait_for_nodes()
+ assert ray.global_state.cluster_resources()["CPU"] == 1
+ runner.step() # run 2
+ assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 1
+
+ for i in range(5):
+ nodes += [cluster.add_node(resources=dict(CPU=1))]
+ assert cluster.wait_for_nodes()
+ assert ray.global_state.cluster_resources()["CPU"] == 6
+
+ runner.step() # 1 result
+ assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2
+
+
+@pytest.mark.skip("Add this test once reconstruction is fixed")
+@pytest.mark.skipif(
+ pytest_timeout is None,
+ reason="Timeout package not installed; skipping test.")
+@pytest.mark.timeout(10, method="thread")
+def test_remove_node_before_result(start_connected_cluster):
+ """Removing a node should cause a Trial to be requeued."""
+ cluster = start_connected_cluster
+ node = cluster.add_node(resources=dict(CPU=1))
+ # TODO(rliaw): Make blocking an option?
+ assert cluster.wait_for_nodes()
+
+ runner = TrialRunner(BasicVariantGenerator())
+ kwargs = {"stopping_criterion": {"training_iteration": 3}}
+ trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
+ for t in trials:
+ runner.add_trial(t)
+
+ runner.step() # run 1
+ runner.step() # run 2
+ assert all(t.status == Trial.RUNNING for t in trials)
+
+ runner.step() # 1 result
+
+ cluster.remove_node(node)
+ cluster.wait_for_nodes()
+ assert ray.global_state.cluster_resources()["CPU"] == 1
+
+ runner.step() # recover
+ for i in range(5):
+ runner.step()
+ assert all(t.status == Trial.TERMINATED for t in trials)
+
+ with pytest.raises(TuneError):
+ runner.step()
+
+
+@pytest.mark.skipif(
+ pytest_timeout is None,
+ reason="Timeout package not installed; skipping test.")
+@pytest.mark.timeout(120, method="thread")
+def test_trial_migration(start_connected_emptyhead_cluster):
+ """Removing a node while cluster has space should migrate trial.
+
+ The trial state should also be consistent with the checkpoint.
+ """
+ cluster = start_connected_emptyhead_cluster
+ node = cluster.add_node(resources=dict(CPU=1))
+ assert cluster.wait_for_nodes()
+
+ runner = TrialRunner(BasicVariantGenerator())
+ kwargs = {
+ "stopping_criterion": {
+ "training_iteration": 3
+ },
+ "checkpoint_freq": 2,
+ "max_failures": 2
+ }
+
+ # Test recovery of trial that hasn't been checkpointed
+ t = Trial("__fake", **kwargs)
+ runner.add_trial(t)
+ runner.step() # start
+ runner.step() # 1 result
+ assert t.last_result is not None
+ node2 = cluster.add_node(resources=dict(CPU=1))
+ cluster.remove_node(node)
+ assert cluster.wait_for_nodes()
+ runner.step() # Recovery step
+
+ # TODO(rliaw): This assertion is not critical but will not pass
+ # because checkpoint handling is messy and should be refactored
+ # rather than hotfixed.
+ # assert t.last_result is None, "Trial result not restored correctly."
+ for i in range(3): + runner.step() + + assert t.status == Trial.TERMINATED + + # Test recovery of trial that has been checkpointed + t2 = Trial("__fake", **kwargs) + runner.add_trial(t2) + runner.step() # start + runner.step() # 1 result + runner.step() # 2 result and checkpoint + assert t2.has_checkpoint() + node3 = cluster.add_node(resources=dict(CPU=1)) + cluster.remove_node(node2) + assert cluster.wait_for_nodes() + runner.step() # Recovery step + assert t2.last_result["training_iteration"] == 2 + for i in range(1): + runner.step() + + assert t2.status == Trial.TERMINATED + + # Test recovery of trial that won't be checkpointed + t3 = Trial("__fake", **{"stopping_criterion": {"training_iteration": 3}}) + runner.add_trial(t3) + runner.step() # start + runner.step() # 1 result + cluster.add_node(resources=dict(CPU=1)) + cluster.remove_node(node3) + assert cluster.wait_for_nodes() + runner.step() # Error handling step + assert t3.status == Trial.ERROR + + with pytest.raises(TuneError): + runner.step() + + +@pytest.mark.skipif( + pytest_timeout is None, + reason="Timeout package not installed; skipping test.") +@pytest.mark.timeout(120, method="thread") +def test_trial_requeue(start_connected_emptyhead_cluster): + """Removing a node in full cluster causes Trial to be requeued.""" + cluster = start_connected_emptyhead_cluster + node = cluster.add_node(resources=dict(CPU=1)) + assert cluster.wait_for_nodes() + + runner = TrialRunner(BasicVariantGenerator()) + kwargs = { + "stopping_criterion": { + "training_iteration": 5 + }, + "checkpoint_freq": 1, + "max_failures": 1 + } + + trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)] + for t in trials: + runner.add_trial(t) + + runner.step() # start + runner.step() # 1 result + + cluster.remove_node(node) + assert cluster.wait_for_nodes() + runner.step() + assert all(t.status == Trial.PENDING for t in trials) + + with pytest.raises(TuneError): + runner.step() diff --git a/python/ray/tune/test/ray_trial_executor_test.py b/python/ray/tune/test/ray_trial_executor_test.py index 35c413e717bb4..86c4bb189595f 100644 --- a/python/ray/tune/test/ray_trial_executor_test.py +++ b/python/ray/tune/test/ray_trial_executor_test.py @@ -9,8 +9,9 @@ from ray.rllib import _register_all from ray.tune import Trainable from ray.tune.ray_trial_executor import RayTrialExecutor +from ray.tune.registry import _global_registry, TRAINABLE_CLASS from ray.tune.suggest import BasicVariantGenerator -from ray.tune.trial import Trial, Checkpoint +from ray.tune.trial import Trial, Checkpoint, Resources class RayTrialExecutorTest(unittest.TestCase): @@ -50,6 +51,12 @@ def testPauseResume(self): self.trial_executor.stop_trial(trial) self.assertEqual(Trial.TERMINATED, trial.status) + def testStartFailure(self): + _global_registry.register(TRAINABLE_CLASS, "asdf", None) + trial = Trial("asdf", resources=Resources(1, 0)) + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.ERROR, trial.status) + def testPauseResume2(self): """Tests that pausing works for trials being processed.""" trial = Trial("__fake") diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 1e4c0509dc151..8e4aa2cea1481 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -3,6 +3,7 @@ from __future__ import print_function import os +import sys import time import unittest @@ -14,15 +15,22 @@ from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.schedulers import TrialScheduler, 
FIFOScheduler from ray.tune.registry import _global_registry, TRAINABLE_CLASS -from ray.tune.result import DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE +from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE, + EPISODES_TOTAL) from ray.tune.util import pin_in_object_store, get_pinned_object from ray.tune.experiment import Experiment from ray.tune.trial import Trial, Resources from ray.tune.trial_runner import TrialRunner from ray.tune.suggest import grid_search, BasicVariantGenerator -from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm +from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm, + SuggestionAlgorithm) from ray.tune.suggest.variant_generator import RecursiveDependencyError +if sys.version_info >= (3, 3): + from unittest.mock import patch +else: + from mock import patch + class TrainableFunctionApiTest(unittest.TestCase): def setUp(self): @@ -106,7 +114,7 @@ def default_resource_request(cls, config): return Resources(cpu=config["cpu"], gpu=config["gpu"]) def _train(self): - return dict(timesteps_this_iter=1, done=True) + return {"timesteps_this_iter": 1, "done": True} register_trainable("B", B) @@ -184,6 +192,21 @@ def train(config, reporter): } }) + def testUploadDirNone(self): + def train(config, reporter): + reporter(timesteps_total=1) + + [trial] = run_experiments({ + "foo": { + "run": train, + "upload_dir": None, + "config": { + "a": "b" + }, + } + }) + self.assertFalse(trial.upload_dir) + def testLogdirStartingWithTilde(self): local_dir = '~/ray_results/local_dir' @@ -418,10 +441,25 @@ def train(config, reporter): }) self.assertIsNone(trial.last_result[TIMESTEPS_TOTAL]) - def train3(config, reporter): + def train2(config, reporter): for i in range(10): reporter(timesteps_total=5) + [trial2] = run_experiments({ + "foo": { + "run": train2, + "config": { + "script_min_iter_time_s": 0, + }, + } + }) + self.assertEqual(trial2.last_result[TIMESTEPS_TOTAL], 5) + self.assertEqual(trial2.last_result["timesteps_this_iter"], 0) + + def train3(config, reporter): + for i in range(10): + reporter(timesteps_this_iter=0, episodes_this_iter=0) + [trial3] = run_experiments({ "foo": { "run": train3, @@ -430,8 +468,73 @@ def train3(config, reporter): }, } }) - self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 5) - self.assertEqual(trial3.last_result["timesteps_this_iter"], 0) + self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 0) + self.assertEqual(trial3.last_result[EPISODES_TOTAL], 0) + + def testCheckpointDict(self): + class TestTrain(Trainable): + def _setup(self, config): + self.state = {"hi": 1} + + def _train(self): + return {"timesteps_this_iter": 1, "done": True} + + def _save(self, path): + return self.state + + def _restore(self, state): + self.state = state + + test_trainable = TestTrain() + result = test_trainable.save() + test_trainable.state["hi"] = 2 + test_trainable.restore(result) + self.assertEqual(test_trainable.state["hi"], 1) + + trials = run_experiments({ + "foo": { + "run": TestTrain, + "checkpoint_at_end": True + } + }) + for trial in trials: + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertTrue(trial.has_checkpoint()) + + def testMultipleCheckpoints(self): + class TestTrain(Trainable): + def _setup(self, config): + self.state = {"hi": 1, "iter": 0} + + def _train(self): + self.state["iter"] += 1 + return {"timesteps_this_iter": 1, "done": True} + + def _save(self, path): + return self.state + + def _restore(self, state): + self.state = state + + test_trainable = TestTrain() + checkpoint_1 = 
test_trainable.save() + test_trainable.train() + checkpoint_2 = test_trainable.save() + self.assertNotEqual(checkpoint_1, checkpoint_2) + test_trainable.restore(checkpoint_2) + self.assertEqual(test_trainable.state["iter"], 1) + test_trainable.restore(checkpoint_1) + self.assertEqual(test_trainable.state["iter"], 0) + + trials = run_experiments({ + "foo": { + "run": TestTrain, + "checkpoint_at_end": True + } + }) + for trial in trials: + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertTrue(trial.has_checkpoint()) class RunExperimentTest(unittest.TestCase): @@ -538,7 +641,7 @@ def train(config, reporter): class B(Trainable): def _train(self): - return dict(timesteps_this_iter=1, done=True) + return {"timesteps_this_iter": 1, "done": True} register_trainable("f1", train) trials = run_experiments({ @@ -558,10 +661,13 @@ def _train(self): def testCheckpointAtEnd(self): class train(Trainable): def _train(self): - return dict(timesteps_this_iter=1, done=True) + return {"timesteps_this_iter": 1, "done": True} def _save(self, path): - return path + checkpoint = path + "/checkpoint" + with open(checkpoint, "w") as f: + f.write("OK") + return checkpoint trials = run_experiments({ "foo": { @@ -745,6 +851,25 @@ def testMaxConcurrentSuggestions(self): self.assertEqual(len(searcher.next_trials()), 0) +def create_mock_components(): + class _MockScheduler(FIFOScheduler): + errored_trials = [] + + def on_trial_error(self, trial_runner, trial): + self.errored_trials += [trial] + + class _MockSearchAlg(BasicVariantGenerator): + errored_trials = [] + + def on_trial_complete(self, trial_id, error=False, **kwargs): + if error: + self.errored_trials += [trial_id] + + searchalg = _MockSearchAlg() + scheduler = _MockScheduler() + return searchalg, scheduler + + class TrialRunnerTest(unittest.TestCase): def tearDown(self): ray.shutdown() @@ -789,16 +914,6 @@ def train(config, reporter): self.assertLessEqual(len(trial.logdir), 200) trial_executor.stop_trial(trial) - def testTrialErrorOnStart(self): - ray.init() - trial_executor = RayTrialExecutor() - _global_registry.register(TRAINABLE_CLASS, "asdf", None) - trial = Trial("asdf", resources=Resources(1, 0)) - try: - trial_executor.start_trial(trial) - except Exception as e: - self.assertIn("a class", str(e)) - def testExtraResources(self): ray.init(num_cpus=4, num_gpus=2) runner = TrialRunner(BasicVariantGenerator()) @@ -821,7 +936,7 @@ def testExtraResources(self): self.assertEqual(trials[1].status, Trial.PENDING) def testFractionalGpus(self): - ray.init(num_cpus=4, num_gpus=1, use_raylet=True) + ray.init(num_cpus=4, num_gpus=1) runner = TrialRunner(BasicVariantGenerator()) kwargs = { "resources": Resources(cpu=1, gpu=0.5), @@ -901,6 +1016,30 @@ def testMultiStepRun(self): self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(trials[1].status, Trial.RUNNING) + def testMultiStepRun2(self): + """Checks that runner.step throws when overstepping.""" + ray.init(num_cpus=1) + runner = TrialRunner(BasicVariantGenerator()) + kwargs = { + "stopping_criterion": { + "training_iteration": 2 + }, + "resources": Resources(cpu=1, gpu=0), + } + trials = [Trial("__fake", **kwargs)] + for t in trials: + runner.add_trial(t) + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + + runner.step() + self.assertEqual(trials[0].status, Trial.TERMINATED) + self.assertRaises(TuneError, runner.step) + def testErrorHandling(self): ray.init(num_cpus=4, num_gpus=2) runner = 
TrialRunner(BasicVariantGenerator()) @@ -923,9 +1062,17 @@ def testErrorHandling(self): self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[1].status, Trial.RUNNING) - def testFailureRecoveryDisabled(self): + def testThrowOnOverstep(self): ray.init(num_cpus=1, num_gpus=1) runner = TrialRunner(BasicVariantGenerator()) + runner.step() + self.assertRaises(TuneError, runner.step) + + def testFailureRecoveryDisabled(self): + ray.init(num_cpus=1, num_gpus=1) + searchalg, scheduler = create_mock_components() + + runner = TrialRunner(searchalg, scheduler=scheduler) kwargs = { "resources": Resources(cpu=1, gpu=1), "checkpoint_freq": 1, @@ -944,10 +1091,15 @@ def testFailureRecoveryDisabled(self): runner.step() self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[0].num_failures, 1) + self.assertEqual(len(searchalg.errored_trials), 1) + self.assertEqual(len(scheduler.errored_trials), 1) def testFailureRecoveryEnabled(self): ray.init(num_cpus=1, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + searchalg, scheduler = create_mock_components() + + runner = TrialRunner(searchalg, scheduler=scheduler) + kwargs = { "resources": Resources(cpu=1, gpu=1), "checkpoint_freq": 1, @@ -968,6 +1120,40 @@ def testFailureRecoveryEnabled(self): self.assertEqual(trials[0].num_failures, 1) runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(len(searchalg.errored_trials), 0) + self.assertEqual(len(scheduler.errored_trials), 0) + + def testFailureRecoveryNodeRemoval(self): + ray.init(num_cpus=1, num_gpus=1) + searchalg, scheduler = create_mock_components() + + runner = TrialRunner(searchalg, scheduler=scheduler) + + kwargs = { + "resources": Resources(cpu=1, gpu=1), + "checkpoint_freq": 1, + "max_failures": 1, + "config": { + "mock_error": True, + }, + } + runner.add_trial(Trial("__fake", **kwargs)) + trials = runner.get_trials() + + with patch('ray.global_state.cluster_resources') as resource_mock: + resource_mock.return_value = {"CPU": 1, "GPU": 1} + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + + # Mimic a node failure + resource_mock.return_value = {"CPU": 0, "GPU": 0} + runner.step() + self.assertEqual(trials[0].status, Trial.PENDING) + self.assertEqual(trials[0].num_failures, 1) + self.assertEqual(len(searchalg.errored_trials), 0) + self.assertEqual(len(scheduler.errored_trials), 1) def testFailureRecoveryMaxFailures(self): ray.init(num_cpus=1, num_gpus=1) @@ -1320,6 +1506,55 @@ def testSearchAlgStalled(self): self.assertTrue(searcher.is_finished()) self.assertTrue(runner.is_finished()) + def testSearchAlgFinishes(self): + """Empty SearchAlg changing state in `next_trials` does not crash.""" + + class FinishFastAlg(SuggestionAlgorithm): + _index = 0 + + def next_trials(self): + trials = [] + self._index += 1 + + for trial in self._trial_generator: + trials += [trial] + break + + if self._index > 4: + self._finished = True + return trials + + def _suggest(self, trial_id): + return {} + + ray.init(num_cpus=2) + experiment_spec = { + "run": "__fake", + "num_samples": 2, + "stop": { + "training_iteration": 1 + } + } + searcher = FinishFastAlg() + experiments = [Experiment.from_json("test", experiment_spec)] + searcher.add_configurations(experiments) + + runner = TrialRunner(search_alg=searcher) + self.assertFalse(runner.is_finished()) + runner.step() # This launches a new run + runner.step() # This launches a 2nd run + 
self.assertFalse(searcher.is_finished()) + self.assertFalse(runner.is_finished()) + runner.step() # This kills the first run + self.assertFalse(searcher.is_finished()) + self.assertFalse(runner.is_finished()) + runner.step() # This kills the 2nd run + self.assertFalse(searcher.is_finished()) + self.assertFalse(runner.is_finished()) + runner.step() # this converts self._finished to True + self.assertTrue(searcher.is_finished()) + self.assertRaises(TuneError, runner.step) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 1e537d26d953d..5824c5221ff5b 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -4,6 +4,7 @@ from datetime import datetime +import copy import gzip import io import logging @@ -83,7 +84,7 @@ def __init__(self, config=None, logger_creator=None): self._timesteps_since_restore = 0 self._iterations_since_restore = 0 self._restored = False - self._setup() + self._setup(copy.deepcopy(self.config)) self._local_ip = ray.services.get_node_ip_address() @classmethod @@ -143,6 +144,8 @@ def train(self): start = time.time() result = self._train() + assert isinstance(result, dict), "_train() needs to return a dict." + result = result.copy() self._iteration += 1 @@ -158,14 +161,14 @@ def train(self): result.setdefault(DONE, False) # self._timesteps_total should only be tracked if increments provided - if result.get(TIMESTEPS_THIS_ITER): + if result.get(TIMESTEPS_THIS_ITER) is not None: if self._timesteps_total is None: self._timesteps_total = 0 self._timesteps_total += result[TIMESTEPS_THIS_ITER] self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER] - # self._timesteps_total should only be tracked if increments provided - if result.get(EPISODES_THIS_ITER): + # self._episodes_total should only be tracked if increments provided + if result.get(EPISODES_THIS_ITER) is not None: if self._episodes_total is None: self._episodes_total = 0 self._episodes_total += result[EPISODES_THIS_ITER] @@ -211,11 +214,38 @@ def save(self, checkpoint_dir=None): Checkpoint path that may be passed to restore(). 
""" - checkpoint_path = self._save(checkpoint_dir or self.logdir) - pickle.dump([ - self._experiment_id, self._iteration, self._timesteps_total, - self._time_total, self._episodes_total - ], open(checkpoint_path + ".tune_metadata", "wb")) + checkpoint_dir = os.path.join(checkpoint_dir or self.logdir, + "checkpoint_{}".format(self._iteration)) + os.makedirs(checkpoint_dir) + checkpoint = self._save(checkpoint_dir) + saved_as_dict = False + if isinstance(checkpoint, str): + if (not checkpoint.startswith(checkpoint_dir) + or checkpoint == checkpoint_dir): + raise ValueError( + "The returned checkpoint path must be within the " + "given checkpoint dir {}: {}".format( + checkpoint_dir, checkpoint)) + if not os.path.exists(checkpoint): + raise ValueError( + "The returned checkpoint path does not exist: {}".format( + checkpoint)) + checkpoint_path = checkpoint + elif isinstance(checkpoint, dict): + saved_as_dict = True + checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + with open(checkpoint_path, "wb") as f: + pickle.dump(checkpoint, f) + else: + raise ValueError("Return value from `_save` must be dict or str.") + pickle.dump({ + "experiment_id": self._experiment_id, + "iteration": self._iteration, + "timesteps_total": self._timesteps_total, + "time_total": self._time_total, + "episodes_total": self._episodes_total, + "saved_as_dict": saved_as_dict + }, open(checkpoint_path + ".tune_metadata", "wb")) return checkpoint_path def save_to_object(self): @@ -259,13 +289,19 @@ def restore(self, checkpoint_path): This method restores additional metadata saved with the checkpoint. """ - self._restore(checkpoint_path) metadata = pickle.load(open(checkpoint_path + ".tune_metadata", "rb")) - self._experiment_id = metadata[0] - self._iteration = metadata[1] - self._timesteps_total = metadata[2] - self._time_total = metadata[3] - self._episodes_total = metadata[4] + self._experiment_id = metadata["experiment_id"] + self._iteration = metadata["iteration"] + self._timesteps_total = metadata["timesteps_total"] + self._time_total = metadata["time_total"] + self._episodes_total = metadata["episodes_total"] + saved_as_dict = metadata["saved_as_dict"] + if saved_as_dict: + with open(checkpoint_path, "rb") as loaded_state: + checkpoint_dict = pickle.load(loaded_state) + self._restore(checkpoint_dict) + else: + self._restore(checkpoint_path) self._restored = True def restore_from_object(self, obj): @@ -318,30 +354,39 @@ def _save(self, checkpoint_dir): Args: checkpoint_dir (str): The directory where the checkpoint - can be stored. + file must be stored. Returns: - Checkpoint path that may be passed to restore(). Typically - would default to `checkpoint_dir`. + checkpoint (str | dict): If string, the return value is + expected to be the checkpoint path that will be passed to + `_restore()`. If dict, the return value will be automatically + serialized by Tune and passed to `_restore()`. + + Examples: + >>> print(trainable1._save("/tmp/checkpoint_1")) + "/tmp/checkpoint_1/my_checkpoint_file" + >>> print(trainable2._save("/tmp/checkpoint_2")) + {"some": "data"} """ raise NotImplementedError - def _restore(self, checkpoint_path): + def _restore(self, checkpoint): """Subclasses should override this to implement restore(). Args: - checkpoint_path (str): The directory where the checkpoint - is stored. + checkpoint (str | dict): Value as returned by `_save`. + If a string, then it is the checkpoint path. 
""" raise NotImplementedError - def _setup(self): + def _setup(self, config): """Subclasses should override this for custom initialization. - Subclasses can access the hyperparameter configuration via - ``self.config``. + Args: + config (dict): Hyperparameters and other configs given. + Copy of `self.config`. """ pass diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 98fcbc6d55e60..f60fd25f2dbac 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -8,6 +8,7 @@ import time import tempfile import os +from numbers import Number import ray from ray.tune import TuneError @@ -33,12 +34,14 @@ class Resources( namedtuple("Resources", ["cpu", "gpu", "extra_cpu", "extra_gpu"])): """Ray resources required to schedule a trial. + TODO: Custom resources. + Attributes: - cpu (int): Number of CPUs to allocate to the trial. - gpu (int): Number of GPUs to allocate to the trial. - extra_cpu (int): Extra CPUs to reserve in case the trial needs to + cpu (float): Number of CPUs to allocate to the trial. + gpu (float): Number of GPUs to allocate to the trial. + extra_cpu (float): Extra CPUs to reserve in case the trial needs to launch additional Ray actors that use CPUs. - extra_gpu (int): Extra GPUs to reserve in case the trial needs to + extra_gpu (float): Extra GPUs to reserve in case the trial needs to launch additional Ray actors that use GPUs. """ @@ -46,6 +49,9 @@ class Resources( __slots__ = () def __new__(cls, cpu, gpu, extra_cpu=0, extra_gpu=0): + for entry in [cpu, gpu, extra_cpu, extra_gpu]: + assert isinstance(entry, Number), "Improper resource value." + assert entry >= 0, "Resource cannot be negative." return super(Resources, cls).__new__(cls, cpu, gpu, extra_cpu, extra_gpu) @@ -79,9 +85,10 @@ class Checkpoint(object): MEMORY = "memory" DISK = "disk" - def __init__(self, storage, value): + def __init__(self, storage, value, last_result=None): self.storage = storage self.value = value + self.last_result = last_result @staticmethod def from_object(value=None): @@ -209,17 +216,19 @@ def should_stop(self, result): return False - def should_checkpoint(self, result): + def should_checkpoint(self): """Whether this trial is due for checkpointing.""" + result = self.last_result or {} if result.get(DONE) and self.checkpoint_at_end: return True - if not self.checkpoint_freq: + if self.checkpoint_freq: + return result.get(TRAINING_ITERATION, + 0) % self.checkpoint_freq == 0 + else: return False - return self.last_result[TRAINING_ITERATION] % self.checkpoint_freq == 0 - def progress_string(self): """Returns a progress message for printing out to the console.""" @@ -271,6 +280,16 @@ def _status_string(self): def has_checkpoint(self): return self._checkpoint.value is not None + def should_recover(self): + """Returns whether the trial qualifies for restoring. + + This is if a checkpoint frequency is set and has not failed more than + max_failures. This may return true even when there may not yet + be a checkpoint. 
+ """ + return (self.checkpoint_freq > 0 + and self.num_failures < self.max_failures) + def update_last_result(self, result, terminate=False): if terminate: result.update(done=True) @@ -299,8 +318,10 @@ def __repr__(self): def __str__(self): """Combines ``env`` with ``trainable_name`` and ``experiment_tag``.""" if "env" in self.config: - identifier = "{}_{}".format(self.trainable_name, - self.config["env"]) + env = self.config["env"] + if isinstance(env, type): + env = env.__name__ + identifier = "{}_{}".format(self.trainable_name, env) else: identifier = self.trainable_name if self.experiment_tag: diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index b961d12b8cbe4..063129780b47a 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -4,7 +4,6 @@ from __future__ import print_function import logging -import traceback from ray.tune.trial import Trial, Checkpoint @@ -33,12 +32,10 @@ def has_resources(self, resources): "has_resources() method") def start_trial(self, trial, checkpoint=None): - """Starts the trial restoring from checkpoint if checkpoint != None. - - If an error is encountered when starting the trial, an exception will - be thrown. + """Starts the trial restoring from checkpoint if checkpoint is provided. Args: + trial (Trial): Trial to be started. checkpoint(Checkpoint): A Python object or path storing the state of trial. """ @@ -60,26 +57,6 @@ def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True): raise NotImplementedError("Subclasses of TrialExecutor must provide " "stop_trial() method") - def restart_trial(self, trial, error_msg=None): - """Restarts the trial. - - The state of the trial should restore from the last checkpoint. - - Args: - error_msg (str): Optional error message. - """ - try: - logger.info( - "Attempting to recover trial state from last checkpoint") - self.stop_trial( - trial, error=True, error_msg=error_msg, stop_logger=False) - trial.result_logger.flush() - self.start_trial(trial) - except Exception: - error_msg = traceback.format_exc() - logger.exception("Error recovering trial from checkpoint, abort.") - self.stop_trial(trial, error=True, error_msg=error_msg) - def continue_training(self, trial): """Continues the training of this trial.""" pass @@ -158,7 +135,11 @@ def fetch_result(self, trial): def debug_string(self): """Returns a human readable message for printing to the console.""" - pass + raise NotImplementedError + + def resource_string(self): + """Returns a string describing the total resources available.""" + raise NotImplementedError def restore(self, trial, checkpoint=None): """Restores training state from a checkpoint. diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 6423d6a95f10e..84457ff8d9e95 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -12,7 +12,7 @@ from ray.tune import TuneError from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.result import TIME_THIS_ITER_S -from ray.tune.trial import Trial +from ray.tune.trial import Trial, Checkpoint from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.web_server import TuneServer @@ -108,6 +108,8 @@ def step(self): Callers should typically run this method repeatedly in a loop. They may inspect or modify the runner's state in between calls to step(). 
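# --- Illustrative sketch (editor's example, not part of this diff). step()
# now raises TuneError once everything has finished (see the guard added at
# the top of step() below), so a driver loop should poll is_finished()
# between calls. Assumes `searcher` is a configured SearchAlgorithm, as in
# the test above.
from ray.tune.trial_runner import TrialRunner

runner = TrialRunner(search_alg=searcher)
while not runner.is_finished():
    runner.step()
# A further runner.step() here would raise TuneError.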
""" + if self.is_finished(): + raise TuneError("Called step when all trials finished?") self.trial_executor.on_step_begin() next_trial = self._get_next_trial() if next_trial is not None: @@ -120,20 +122,19 @@ def step(self): if not self.has_resources(trial.resources): raise TuneError( ("Insufficient cluster resources to launch trial: " - "trial requested {} but the cluster summary: {} " + "trial requested {} but the cluster has only {}. " "Pass `queue_trials=True` in " "ray.tune.run_experiments() or on the command " "line to queue trials until the cluster scales " "up. {}").format( trial.resources.summary_string(), - self.trial_executor.debug_string(), + self.trial_executor.resource_string(), trial._get_trainable_cls().resource_help( trial.config))) elif trial.status == Trial.PAUSED: raise TuneError( "There are paused trials, but no more pending " "trials with sufficient resources.") - raise TuneError("Called step when all trials finished?") if self._server: self._process_requests() @@ -215,8 +216,29 @@ def _debug_messages(self): messages = ["== Status =="] messages.append(self._scheduler_alg.debug_string()) messages.append(self.trial_executor.debug_string()) + messages.append(self._memory_debug_string()) return messages + def _memory_debug_string(self): + try: + import psutil + total_gb = psutil.virtual_memory().total / 1e9 + used_gb = total_gb - psutil.virtual_memory().available / 1e9 + if used_gb > total_gb * 0.9: + warn = (": ***LOW MEMORY*** less than 10% of the memory on " + "this node is available for use. This can cause " + "unexpected crashes. Consider " + "reducing the memory used by your application " + "or reducing the Ray object store size by setting " + "`object_store_memory` when calling `ray.init`.") + else: + warn = "" + return "Memory usage on this node: {}/{} GB{}".format( + round(used_gb, 1), round(total_gb, 1), warn) + except ImportError: + return ("Unknown memory usage. 
Please run `pip install psutil` " + "(or ray[debug]) to resolve)") + def has_resources(self, resources): """Returns whether this runner has at least the specified resources.""" return self.trial_executor.has_resources(resources) @@ -257,17 +279,14 @@ def _process_events(self): result, terminate=(decision == TrialScheduler.STOP)) if decision == TrialScheduler.CONTINUE: - if trial.should_checkpoint(result): - # TODO(rliaw): This is a blocking call - self.trial_executor.save(trial) + self._checkpoint_if_needed(trial) self.trial_executor.continue_training(trial) elif decision == TrialScheduler.PAUSE: self.trial_executor.pause_trial(trial) elif decision == TrialScheduler.STOP: # Checkpoint before ending the trial # if checkpoint_at_end experiment option is set to True - if trial.should_checkpoint(result): - self.trial_executor.save(trial) + self._checkpoint_if_needed(trial) self.trial_executor.stop_trial(trial) else: assert False, "Invalid scheduling decision: {}".format( @@ -276,24 +295,61 @@ def _process_events(self): logger.exception("Error processing event.") error_msg = traceback.format_exc() if trial.status == Trial.RUNNING: - if trial.has_checkpoint() and \ - trial.num_failures < trial.max_failures: + if trial.should_recover(): self._try_recover(trial, error_msg) else: self._scheduler_alg.on_trial_error(self, trial) self._search_alg.on_trial_complete( trial.trial_id, error=True) - self.trial_executor.stop_trial(trial, True, error_msg) + self.trial_executor.stop_trial( + trial, error=True, error_msg=error_msg) + + def _checkpoint_if_needed(self, trial): + """Checkpoints trial based off trial.last_result.""" + if trial.should_checkpoint(): + # Save trial runtime if possible + if hasattr(trial, "runner") and trial.runner: + self.trial_executor.save(trial, storage=Checkpoint.DISK) def _try_recover(self, trial, error_msg): + """Tries to recover trial. + + Notifies SearchAlgorithm and Scheduler if failure to recover. + + Args: + trial (Trial): Trial to recover. + error_msg (str): Error message from prior to invoking this method. + """ try: - logger.info("Attempting to recover" - " trial state from last checkpoint.") - self.trial_executor.restart_trial(trial, error_msg) + self.trial_executor.stop_trial( + trial, + error=error_msg is not None, + error_msg=error_msg, + stop_logger=False) + trial.result_logger.flush() + if self.trial_executor.has_resources(trial.resources): + logger.info("Attempting to recover" + " trial state from last checkpoint.") + self.trial_executor.start_trial(trial) + if trial.status == Trial.ERROR: + raise RuntimeError("Trial did not start correctly.") + else: + logger.debug("Notifying Scheduler and requeueing trial.") + self._requeue_trial(trial) except Exception: - error_msg = traceback.format_exc() - logger.warning("Error recovering trial from checkpoint, abort.") - self.trial_executor.stop_trial(trial, True, error_msg=error_msg) + logger.exception("Error recovering trial from checkpoint, abort.") + self._scheduler_alg.on_trial_error(self, trial) + self._search_alg.on_trial_complete(trial.trial_id, error=True) + + def _requeue_trial(self, trial): + """Notification to TrialScheduler and requeue trial. + + This does not notify the SearchAlgorithm because + the function evaluation is still in progress. + """ + self._scheduler_alg.on_trial_error(self, trial) + trial.status = Trial.PENDING + self._scheduler_alg.on_trial_add(self, trial) def _update_trial_queue(self, blocking=False, timeout=600): """Adds next trials to queue if possible. 
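# --- Illustrative sketch (editor's example, not part of this diff). The
# memory warning assembled in _memory_debug_string above is driven by psutil;
# the same figures can be reproduced standalone, assuming `pip install psutil`.
import psutil

vm = psutil.virtual_memory()
total_gb = vm.total / 1e9
used_gb = total_gb - vm.available / 1e9
print("Memory usage on this node: {}/{} GB".format(
    round(used_gb, 1), round(total_gb, 1)))
if used_gb > total_gb * 0.9:
    print("***LOW MEMORY*** less than 10% of this node's memory is free.")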
@@ -302,13 +358,15 @@ def _update_trial_queue(self, blocking=False, timeout=600): Args: blocking (bool): Blocks until either a trial is available - or the Runner finishes (i.e., timeout or search algorithm - finishes). + or is_finished (timeout or search algorithm finishes). timeout (int): Seconds before blocking times out. """ trials = self._search_alg.next_trials() if blocking and not trials: start = time.time() + # Checking `is_finished` instead of _search_alg.is_finished + # is fine because blocking only occurs if all trials are + # finished and search_algorithm is not yet finished while (not trials and not self.is_finished() and time.time() - start < timeout): logger.info("Blocking for next trial...") diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py index 691d25adbe97e..9c047fd80043e 100644 --- a/python/ray/tune/util.py +++ b/python/ray/tune/util.py @@ -28,7 +28,7 @@ def pin_in_object_store(obj): def get_pinned_object(pinned_id): """Retrieve a pinned object from the object store.""" - from ray.local_scheduler import ObjectID + from ray.raylet import ObjectID return _from_pinnable( ray.get( diff --git a/python/ray/utils.py b/python/ray/utils.py index 0f6adaea98683..e75e006721444 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -5,6 +5,7 @@ import binascii import functools import hashlib +import inspect import numpy as np import os import subprocess @@ -14,11 +15,9 @@ import uuid import ray.gcs_utils -import ray.local_scheduler +import ray.raylet import ray.ray_constants as ray_constants -ERROR_KEY_PREFIX = b"Error:" - def _random_string(): id_hash = hashlib.sha1() @@ -69,22 +68,12 @@ def push_error_to_driver(worker, """ if driver_id is None: driver_id = ray_constants.NIL_JOB_ID.id() - error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string() data = {} if data is None else data - if not worker.use_raylet: - worker.redis_client.hmset(error_key, { - "type": error_type, - "message": message, - "data": data - }) - worker.redis_client.rpush("ErrorKeys", error_key) - else: - worker.local_scheduler_client.push_error( - ray.ObjectID(driver_id), error_type, message, time.time()) + worker.local_scheduler_client.push_error( + ray.ObjectID(driver_id), error_type, message, time.time()) def push_error_to_driver_through_redis(redis_client, - use_raylet, error_type, message, driver_id=None, @@ -98,8 +87,6 @@ def push_error_to_driver_through_redis(redis_client, Args: redis_client: The redis client to use. - use_raylet: True if we are using the Raylet code path and false - otherwise. error_type (str): The type of the error. message (str): The message that will be printed in the background on the driver. @@ -110,23 +97,14 @@ def push_error_to_driver_through_redis(redis_client, """ if driver_id is None: driver_id = ray_constants.NIL_JOB_ID.id() - error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string() data = {} if data is None else data - if not use_raylet: - redis_client.hmset(error_key, { - "type": error_type, - "message": message, - "data": data - }) - redis_client.rpush("ErrorKeys", error_key) - else: - # Do everything in Python and through the Python Redis client instead - # of through the raylet. - error_data = ray.gcs_utils.construct_error_message( - driver_id, error_type, message, time.time()) - redis_client.execute_command( - "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data) + # Do everything in Python and through the Python Redis client instead + # of through the raylet. 
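# --- Illustrative round trip (editor's example, not part of this diff).
# get_pinned_object above now resolves ObjectID via ray.raylet; its usual
# pairing with pin_in_object_store is otherwise unchanged. Requires a running
# Ray instance.
import ray
from ray.tune.util import pin_in_object_store, get_pinned_object

ray.init()
pinned_id = pin_in_object_store({"weights": [1, 2, 3]})
assert get_pinned_object(pinned_id) == {"weights": [1, 2, 3]}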
+ error_data = ray.gcs_utils.construct_error_message(driver_id, error_type, + message, time.time()) + redis_client.execute_command( + "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, + ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data) def is_cython(obj): @@ -144,6 +122,23 @@ def check_cython(x): (hasattr(obj, "__func__") and check_cython(obj.__func__)) +def is_function_or_method(obj): + """Check if an object is a function or method. + + Args: + obj: The Python object in question. + + Returns: + True if the object is an function or method. + """ + return (inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj)) + + +def is_class_method(f): + """Returns whether the given method is a class_method.""" + return hasattr(f, "__self__") and f.__self__ is not None + + def random_string(): """Generate a random string to use as an ID. @@ -329,6 +324,28 @@ def get_system_memory(): return sysctl(["sysctl", "hw.memsize"]) +def get_shared_memory_bytes(): + """Get the size of the shared memory file system. + + Returns: + The size of the shared memory file system in bytes. + """ + # Make sure this is only called on Linux. + assert sys.platform == "linux" or sys.platform == "linux2" + + shm_fd = os.open("/dev/shm", os.O_RDONLY) + try: + shm_fs_stats = os.fstatvfs(shm_fd) + # The value shm_fs_stats.f_bsize is the block size and the + # value shm_fs_stats.f_bavail is the number of available + # blocks. + shm_avail = shm_fs_stats.f_bsize * shm_fs_stats.f_bavail + finally: + os.close(shm_fd) + + return shm_avail + + def check_oversized_pickle(pickled, name, obj_type, worker): """Send a warning message if the pickled object is too large. @@ -406,3 +423,7 @@ def thread_safe_client(client, lock=None): if lock is None: lock = threading.Lock() return _ThreadSafeProxy(client, lock) + + +def is_main_thread(): + return threading.current_thread().getName() == "MainThread" diff --git a/python/ray/worker.py b/python/ray/worker.py index 2d1d45f65b1c3..de2513780ad5a 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -2,9 +2,10 @@ from __future__ import division from __future__ import print_function +from contextlib import contextmanager import atexit -import collections import colorama +import faulthandler import hashlib import inspect import logging @@ -23,17 +24,19 @@ import ray.cloudpickle as pickle import ray.experimental.state as state import ray.gcs_utils +import ray.memory_monitor as memory_monitor import ray.remote_function import ray.serialization as serialization import ray.services as services import ray.signature -import ray.local_scheduler +import ray.tempfile_services as tempfile_services +import ray.raylet import ray.plasma import ray.ray_constants as ray_constants from ray import import_thread from ray import profiling +from ray.function_manager import FunctionActorManager from ray.utils import ( - binary_to_hex, check_oversized_pickle, is_cython, random_string, @@ -55,14 +58,6 @@ NIL_ACTOR_HANDLE_ID = NIL_ID NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff" -# This must be kept in sync with the `error_types` array in -# common/state/error_table.h. -OBJECT_HASH_MISMATCH_ERROR_TYPE = b"object_hash_mismatch" -PUT_RECONSTRUCTION_ERROR_TYPE = b"put_reconstruction" - -# This must be kept in sync with the `scheduling_state` enum in common/task.h. -TASK_STATUS_RUNNING = 8 - # Default resource requirements for actors when no resource requirements are # specified. 
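# --- Illustrative usage (editor's example, not part of this diff). The new
# get_shared_memory_bytes helper above asserts it is running on Linux, so
# callers should guard on the platform themselves.
import sys
import ray.utils

if sys.platform.startswith("linux"):
    shm_bytes = ray.utils.get_shared_memory_bytes()
    print("/dev/shm has {:.1f} GB available".format(shm_bytes / 1e9))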
DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE = 1 @@ -77,6 +72,11 @@ # using logging.basicConfig in its entry/init points. logger = logging.getLogger(__name__) +try: + import setproctitle +except ImportError: + setproctitle = None + class RayTaskError(Exception): """An object used internally to represent a task that threw an exception. @@ -175,11 +175,6 @@ def __str__(self): self.task_error)) -FunctionExecutionInfo = collections.namedtuple( - "FunctionExecutionInfo", ["function", "function_name", "max_calls"]) -"""FunctionExecutionInfo: A named tuple storing remote function information.""" - - class Worker(object): """A class used to define the control flow of a worker process. @@ -188,19 +183,9 @@ class Worker(object): functions outside of this class are considered exposed. Attributes: - function_execution_info (Dict[str, FunctionExecutionInfo]): A - dictionary mapping the name of a remote function to the remote - function itself. This is the set of remote functions that can be - executed by this worker. connected (bool): True if Ray has been started and False otherwise. mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and WORKER_MODE. - cached_remote_functions_and_actors: A list of information for exporting - remote functions and actor classes definitions that were defined - before the worker called connect. When the worker eventually does - call connect, if it is a driver, it will export these functions and - actors. If cached_remote_functions_and_actors is None, that means - that connect has been called already. cached_functions_to_run (List): A list of functions to run on all of the workers that should be exported as soon as connect is called. profiler: the profiler used to aggregate profiling information. @@ -215,24 +200,9 @@ class Worker(object): def __init__(self): """Initialize a Worker object.""" - # This field is a dictionary that maps a driver ID to a dictionary of - # functions (and information about those functions) that have been - # registered for that driver (this inner dictionary maps function IDs - # to a FunctionExecutionInfo object. This should only be used on - # workers that execute remote functions. - self.function_execution_info = collections.defaultdict(lambda: {}) - # This is a dictionary mapping driver ID to a dictionary that maps - # remote function IDs for that driver to a counter of the number of - # times that remote function has been executed on this worker. The - # counter is incremented every time the function is executed on this - # worker. When the counter reaches the maximum number of executions - # allowed for a particular function, the worker is killed. - self.num_task_executions = collections.defaultdict(lambda: {}) self.connected = False self.mode = None - self.cached_remote_functions_and_actors = [] self.cached_functions_to_run = [] - self.fetch_and_register_actor = None self.actor_init_error = None self.make_actor = None self.actors = {} @@ -247,13 +217,44 @@ def __init__(self): # When the worker is constructed. Record the original value of the # CUDA_VISIBLE_DEVICES environment variable. self.original_gpu_ids = ray.utils.get_cuda_visible_devices() - self.profiler = profiling.Profiler(self) + self.profiler = None + self.memory_monitor = memory_monitor.MemoryMonitor() self.state_lock = threading.Lock() # A dictionary that maps from driver id to SerializationContext # TODO: clean up the SerializationContext once the job finished. 
self.serialization_context_map = {} + self.function_actor_manager = FunctionActorManager(self) + # Reads/writes to the following fields must be protected by + # self.state_lock. # Identity of the driver that this worker is processing. - self.task_driver_id = None + self.task_driver_id = ray.ObjectID(NIL_ID) + self.current_task_id = ray.ObjectID(NIL_ID) + self.task_index = 0 + self.put_index = 1 + + def get_current_thread_task_id(self): + """Get the current thread's task ID. + + This returns the assigned task ID if called on the main thread, else a + random task ID. This method is not thread-safe and must be called with + self.state_lock acquired. + """ + current_task_id = self.current_task_id + if not ray.utils.is_main_thread(): + # If this is running on a separate thread, then the mapping + # to the current task ID may not be correct. Generate a + # random task ID so that the backend can differentiate + # between different threads. + current_task_id = ray.ObjectID(random_string()) + if not self.multithreading_warned: + logger.warning( + "Calling ray.get or ray.wait in a separate thread " + "may lead to deadlock if the main thread blocks on this " + "thread and there are not enough resources to execute " + "more tasks") + self.multithreading_warned = True + assert not current_task_id.is_nil() + return current_task_id def mark_actor_init_failed(self, error): """Called to mark this actor as failed during initialization.""" @@ -349,7 +350,7 @@ def store_and_register(self, object_id, value, depth=100): "of their fields. This behavior may " "be incorrect in some cases.".format( type(e.example_object))) - logger.warning(warning_message) + logger.debug(warning_message) except (serialization.RayNotDictionarySerializable, serialization.CloudPickleError, pickle.pickle.PicklingError, Exception): @@ -413,6 +414,17 @@ def put_object(self, object_id, value): logger.info( "The object with ID {} already exists in the object store." .format(object_id)) + except TypeError: + # This error can happen because one of the members of the object + # may not be serializable for cloudpickle. So we need these extra + # fallbacks here to start from the beginning. Hopefully the object + # could have a `__reduce__` method. + register_custom_serializer(type(value), use_pickle=True) + warning_message = ("WARNING: Serializing the class {} failed, " + "so are are falling back to cloudpickle." + .format(type(value))) + logger.warning(warning_message) + self.store_and_register(object_id, value) def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10): start_time = time.time() @@ -439,7 +451,8 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10): invalid_error = RayTaskError( "", None, "Invalid return value: likely worker died or was killed " - "while executing the task.") + "while executing the task; check previous logs or dmesg " + "for errors.") return [invalid_error] * len(object_ids) except pyarrow.DeserializationCallbackError: # Wait a little bit for the import thread to import the class. 
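# --- Illustrative sketch (editor's example, not part of this diff). The
# TypeError fallback added to put_object above does automatically what a user
# can also request up front via ray.register_custom_serializer (as used
# internally above). `OpaqueHandle` is a made-up class standing in for a type
# the default serializer cannot handle.
import ray

class OpaqueHandle(object):
    def __init__(self, value):
        self.value = value

ray.init()
# Force pickle-based serialization, mirroring the fallback in put_object.
ray.register_custom_serializer(OpaqueHandle, use_pickle=True)
oid = ray.put(OpaqueHandle(42))
assert ray.get(oid).value == 42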
@@ -489,13 +502,9 @@ def get_object(self, object_ids): ] for i in range(0, len(object_ids), ray._config.worker_fetch_request_size()): - if not self.use_raylet: - self.plasma_client.fetch(plain_object_ids[i:( - i + ray._config.worker_fetch_request_size())]) - else: - self.local_scheduler_client.reconstruct_objects( - object_ids[i:( - i + ray._config.worker_fetch_request_size())], True) + self.local_scheduler_client.fetch_or_reconstruct( + object_ids[i:(i + ray._config.worker_fetch_request_size())], + True) # Get the objects. We initially try to get the objects immediately. final_results = self.retrieve_and_deserialize(plain_object_ids, 0) @@ -509,6 +518,9 @@ def get_object(self, object_ids): if len(unready_ids) > 0: with self.state_lock: + # Get the task ID, to notify the backend which task is blocked. + current_task_id = self.get_current_thread_task_id() + # Try reconstructing any objects we haven't gotten yet. Try to # get them until at least get_timeout_milliseconds # milliseconds passes, then repeat. @@ -525,25 +537,10 @@ def get_object(self, object_ids): ray._config.worker_fetch_request_size()) for i in range(0, len(object_ids_to_fetch), fetch_request_size): - if not self.use_raylet: - for unready_id in ray_object_ids_to_fetch[i:( - i + fetch_request_size)]: - (self.local_scheduler_client. - reconstruct_objects([unready_id], False)) - # Do another fetch for objects that aren't - # available locally yet, in case they were evicted - # since the last fetch. We divide the fetch into - # smaller fetches so as to not block the manager - # for a prolonged period of time in a single call. - # This is only necessary for legacy ray since - # reconstruction and fetch are implemented by - # different processes. - self.plasma_client.fetch(object_ids_to_fetch[i:( - i + fetch_request_size)]) - else: - self.local_scheduler_client.reconstruct_objects( - ray_object_ids_to_fetch[i:( - i + fetch_request_size)], False) + self.local_scheduler_client.fetch_or_reconstruct( + ray_object_ids_to_fetch[i:( + i + fetch_request_size)], False, + current_task_id) results = self.retrieve_and_deserialize( object_ids_to_fetch, max([ @@ -561,7 +558,7 @@ def get_object(self, object_ids): # If there were objects that we weren't able to get locally, # let the local scheduler know that we're now unblocked. - self.local_scheduler_client.notify_unblocked() + self.local_scheduler_client.notify_unblocked(current_task_id) assert len(final_results) == len(object_ids) return final_results @@ -578,6 +575,7 @@ def submit_task(self, execution_dependencies=None, num_return_vals=None, resources=None, + placement_resources=None, driver_id=None): """Submit a remote task to the scheduler. @@ -603,6 +601,9 @@ def submit_task(self, num_return_vals: The number of return values this function should have. resources: The resource requirements for this task. + placement_resources: The resources required for placing the task. + If this is not provided or if it is an empty dictionary, then + the placement resources will be equal to resources. driver_id: The ID of the relevant driver. This is almost always the driver ID of the driver that is currently running. 
However, in the exceptional case that an actor task is being dispatched to @@ -632,7 +633,7 @@ def submit_task(self, for arg in args: if isinstance(arg, ray.ObjectID): args_for_local_scheduler.append(arg) - elif ray.local_scheduler.check_simple_value(arg): + elif ray.raylet.check_simple_value(arg): args_for_local_scheduler.append(arg) else: args_for_local_scheduler.append(put(arg)) @@ -656,74 +657,28 @@ def submit_task(self, raise ValueError( "Resource quantities must all be whole numbers.") + if placement_resources is None: + placement_resources = {} + with self.state_lock: # Increment the worker's task index to track how many tasks # have been submitted by the current task so far. task_index = self.task_index self.task_index += 1 + # The parent task must be set for the submitted task. + assert not self.current_task_id.is_nil() # Submit the task to local scheduler. - task = ray.local_scheduler.Task( + task = ray.raylet.Task( driver_id, ray.ObjectID( function_id.id()), args_for_local_scheduler, num_return_vals, self.current_task_id, task_index, actor_creation_id, actor_creation_dummy_object_id, actor_id, - actor_handle_id, actor_counter, is_actor_checkpoint_method, - execution_dependencies, resources, self.use_raylet) + actor_handle_id, actor_counter, execution_dependencies, + resources, placement_resources) self.local_scheduler_client.submit(task) return task.returns() - def export_remote_function(self, function_id, function_name, function, - max_calls, decorated_function): - """Export a remote function. - - Args: - function_id: The ID of the function. - function_name: The name of the function. - function: The raw undecorated function to export. - max_calls: The maximum number of times a given worker can execute - this function before exiting. - decorated_function: The decorated function (this is used to enable - the remote function to recursively call itself). - """ - if self.mode != SCRIPT_MODE: - raise Exception("export_remote_function can only be called on a " - "driver.") - - key = (b"RemoteFunction:" + self.task_driver_id.id() + b":" + - function_id.id()) - - # Work around limitations of Python pickling. - function_name_global_valid = function.__name__ in function.__globals__ - function_name_global_value = function.__globals__.get( - function.__name__) - # Allow the function to reference itself as a global variable - if not is_cython(function): - function.__globals__[function.__name__] = decorated_function - try: - pickled_function = pickle.dumps(function) - finally: - # Undo our changes - if function_name_global_valid: - function.__globals__[function.__name__] = ( - function_name_global_value) - else: - del function.__globals__[function.__name__] - - check_oversized_pickle(pickled_function, function_name, - "remote function", self) - - self.redis_client.hmset( - key, { - "driver_id": self.task_driver_id.id(), - "function_id": function_id.id(), - "name": function_name, - "module": function.__module__, - "function": pickled_function, - "max_calls": max_calls - }) - self.redis_client.rpush("Exports", key) - def run_function_on_all_workers(self, function, run_on_other_drivers=False): """Run arbitrary code on all of the workers. 
@@ -773,7 +728,7 @@ def run_function_on_all_workers(self, function, "driver_id": self.task_driver_id.id(), "function_id": function_to_run_id, "function": pickled_function, - "run_on_other_drivers": run_on_other_drivers + "run_on_other_drivers": str(run_on_other_drivers) }) self.redis_client.rpush("Exports", key) # TODO(rkn): If the worker fails after it calls setnx and before it @@ -782,47 +737,6 @@ def run_function_on_all_workers(self, function, # operations into a transaction (or by implementing a custom # command that does all three things). - def _wait_for_function(self, function_id, driver_id, timeout=10): - """Wait until the function to be executed is present on this worker. - - This method will simply loop until the import thread has imported the - relevant function. If we spend too long in this loop, that may indicate - a problem somewhere and we will push an error message to the user. - - If this worker is an actor, then this will wait until the actor has - been defined. - - Args: - function_id (str): The ID of the function that we want to execute. - driver_id (str): The ID of the driver to push the error message to - if this times out. - """ - start_time = time.time() - # Only send the warning once. - warning_sent = False - while True: - with self.lock: - if (self.actor_id == NIL_ACTOR_ID - and (function_id.id() in - self.function_execution_info[driver_id])): - break - elif self.actor_id != NIL_ACTOR_ID and ( - self.actor_id in self.actors): - break - if time.time() - start_time > timeout: - warning_message = ("This worker was asked to execute a " - "function that it does not have " - "registered. You may have to restart " - "Ray.") - if not warning_sent: - ray.utils.push_error_to_driver( - self, - ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR, - warning_message, - driver_id=driver_id) - warning_sent = True - time.sleep(0.001) - def _get_arguments_for_execution(self, function_name, serialized_args): """Retrieve the arguments for the remote function. @@ -863,7 +777,7 @@ def _get_arguments_for_execution(self, function_name, serialized_args): arguments.append(argument) return arguments - def _store_outputs_in_objstore(self, object_ids, outputs): + def _store_outputs_in_object_store(self, object_ids, outputs): """Store the outputs of a remote function in the local object store. This stores the values that were returned by a remote function in the @@ -890,7 +804,7 @@ def _store_outputs_in_objstore(self, object_ids, outputs): self.put_object(object_ids[i], outputs[i]) - def _process_task(self, task): + def _process_task(self, task, function_execution_info): """Execute a task assigned to this worker. This method deserializes a task from the scheduler, and attempts to @@ -900,37 +814,41 @@ def _process_task(self, task): (these will be retrieved by calls to get or by subsequent tasks that use the outputs of this task). """ - # The ID of the driver that this task belongs to. This is needed so - # that if the task throws an exception, we propagate the error - # message to the correct driver. - self.task_driver_id = task.driver_id() - self.current_task_id = task.task_id() - self.task_index = 0 - self.put_index = 1 + with self.state_lock: + assert self.task_driver_id.is_nil() + assert self.current_task_id.is_nil() + assert self.task_index == 0 + assert self.put_index == 1 + + # The ID of the driver that this task belongs to. This is needed so + # that if the task throws an exception, we propagate the error + # message to the correct driver. 
+ self.task_driver_id = task.driver_id() + self.current_task_id = task.task_id() + function_id = task.function_id() args = task.arguments() return_object_ids = task.returns() if task.actor_id().id() != NIL_ACTOR_ID: dummy_return_id = return_object_ids.pop() - function_executor = self.function_execution_info[ - self.task_driver_id.id()][function_id.id()].function - function_name = self.function_execution_info[self.task_driver_id.id()][ - function_id.id()].function_name + function_executor = function_execution_info.function + function_name = function_execution_info.function_name # Get task arguments from the object store. try: if function_name != "__ray_terminate__": self.reraise_actor_init_error() + self.memory_monitor.raise_if_low_memory() with profiling.profile("task:deserialize_arguments", worker=self): arguments = self._get_arguments_for_execution( function_name, args) except (RayGetError, RayGetArgumentError) as e: - self._handle_process_task_failure(function_id, return_object_ids, - e, None) + self._handle_process_task_failure(function_id, function_name, + return_object_ids, e, None) return except Exception as e: self._handle_process_task_failure( - function_id, return_object_ids, e, + function_id, function_name, return_object_ids, e, ray.utils.format_error_message(traceback.format_exc())) return @@ -949,8 +867,9 @@ def _process_task(self, task): task_exception = task.actor_id().id() == NIL_ACTOR_ID traceback_str = ray.utils.format_error_message( traceback.format_exc(), task_exception=task_exception) - self._handle_process_task_failure(function_id, return_object_ids, - e, traceback_str) + self._handle_process_task_failure(function_id, function_name, + return_object_ids, e, + traceback_str) return # Store the outputs in the local object store. @@ -962,21 +881,19 @@ def _process_task(self, task): num_returns = len(return_object_ids) if num_returns == 1: outputs = (outputs, ) - self._store_outputs_in_objstore(return_object_ids, outputs) + self._store_outputs_in_object_store(return_object_ids, outputs) except Exception as e: self._handle_process_task_failure( - function_id, return_object_ids, e, + function_id, function_name, return_object_ids, e, ray.utils.format_error_message(traceback.format_exc())) - def _handle_process_task_failure(self, function_id, return_object_ids, - error, backtrace): - function_name = self.function_execution_info[self.task_driver_id.id()][ - function_id.id()].function_name + def _handle_process_task_failure(self, function_id, function_name, + return_object_ids, error, backtrace): failure_object = RayTaskError(function_name, error, backtrace) failure_objects = [ failure_object for _ in range(len(return_object_ids)) ] - self._store_outputs_in_objstore(return_object_ids, failure_objects) + self._store_outputs_in_object_store(return_object_ids, failure_objects) # Log the error message. ray.utils.push_error_to_driver( self, @@ -1013,7 +930,7 @@ def _become_actor(self, task): time.sleep(0.001) with self.lock: - self.fetch_and_register_actor(key, self) + self.function_actor_manager.fetch_and_register_actor(key) def _wait_for_and_process_task(self, task): """Wait for a task to be ready and process the task. @@ -1030,11 +947,8 @@ def _wait_for_and_process_task(self, task): self._become_actor(task) return - # Wait until the function to be executed has actually been registered - # on this worker. We will push warnings to the user if we spend too - # long in this loop. 
- with profiling.profile("wait_for_function", worker=self): - self._wait_for_function(function_id, driver_id) + execution_info = self.function_actor_manager.get_execution_info( + driver_id, function_id) # Execute the task. # TODO(rkn): Consider acquiring this lock with a timeout and pushing a @@ -1042,38 +956,38 @@ def _wait_for_and_process_task(self, task): # because that may indicate that the system is hanging, and it'd be # good to know where the system is hanging. with self.lock: - - function_name = (self.function_execution_info[driver_id][ - function_id.id()]).function_name - if not self.use_raylet: - extra_data = { - "function_name": function_name, - "task_id": task.task_id().hex(), - "worker_id": binary_to_hex(self.worker_id) - } + function_name = execution_info.function_name + extra_data = { + "name": function_name, + "task_id": task.task_id().hex() + } + if task.actor_id().id() == NIL_ACTOR_ID: + title = "ray_worker:{}()".format(function_name) + next_title = "ray_worker" else: - extra_data = { - "name": function_name, - "task_id": task.task_id().hex() - } + actor = self.actors[task.actor_id().id()] + title = "ray_{}:{}()".format(actor.__class__.__name__, + function_name) + next_title = "ray_{}".format(actor.__class__.__name__) with profiling.profile("task", extra_data=extra_data, worker=self): - self._process_task(task) - - # In the non-raylet code path, push all of the log events to the global - # state store. In the raylet code path, this is done periodically in a - # background thread. - if not self.use_raylet: - self.profiler.flush_profile_data() + with _changeproctitle(title, next_title): + self._process_task(task, execution_info) + # Reset the state fields so the next task can run. + with self.state_lock: + self.task_driver_id = ray.ObjectID(NIL_ID) + self.current_task_id = ray.ObjectID(NIL_ID) + self.task_index = 0 + self.put_index = 1 # Increase the task execution counter. - self.num_task_executions[driver_id][function_id.id()] += 1 + self.function_actor_manager.increase_task_counter( + driver_id, function_id.id()) - reached_max_executions = ( - self.num_task_executions[driver_id][function_id.id()] == self. - function_execution_info[driver_id][function_id.id()].max_calls) + reached_max_executions = (self.function_actor_manager.get_task_counter( + driver_id, function_id.id()) == execution_info.max_calls) if reached_max_executions: self.local_scheduler_client.disconnect() - os._exit(0) + sys.exit(0) def _get_next_task_from_local_scheduler(self): """Get the next task from the local scheduler. @@ -1081,7 +995,7 @@ def _get_next_task_from_local_scheduler(self): Returns: A task from the local scheduler. """ - with profiling.profile("get_task", worker=self): + with profiling.profile("worker_idle", worker=self): task = self.local_scheduler_client.get_task() # Automatically restrict the GPUs available to this task. 
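# --- Hypothetical reconstruction (editor's sketch, not part of this diff).
# The _changeproctitle helper used above is defined elsewhere in worker.py and
# is not shown in this hunk; a version consistent with its usage and with the
# optional setproctitle import added earlier might look like this.
from contextlib import contextmanager

try:
    import setproctitle
except ImportError:
    setproctitle = None

@contextmanager
def _changeproctitle(title, next_title):
    # Label the process while the task runs, then switch to the idle title.
    if setproctitle:
        setproctitle.setproctitle(title)
    yield
    if setproctitle:
        setproctitle.setproctitle(next_title)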
@@ -1118,13 +1032,10 @@ def get_gpu_ids(): raise Exception("ray.get_gpu_ids() currently does not work in PYTHON " "MODE.") - if not global_worker.use_raylet: - assigned_ids = global_worker.local_scheduler_client.gpu_ids() - else: - all_resource_ids = global_worker.local_scheduler_client.resource_ids() - assigned_ids = [ - resource_id for resource_id, _ in all_resource_ids.get("GPU", []) - ] + all_resource_ids = global_worker.local_scheduler_client.resource_ids() + assigned_ids = [ + resource_id for resource_id, _ in all_resource_ids.get("GPU", []) + ] # If the user had already set CUDA_VISIBLE_DEVICES, then respect that (in # the sense that only GPU IDs that appear in CUDA_VISIBLE_DEVICES should be # returned). @@ -1139,17 +1050,11 @@ def get_gpu_ids(): def get_resource_ids(): """Get the IDs of the resources that are available to the worker. - This function is only supported in the raylet code path. - Returns: A dictionary mapping the name of a resource to a list of pairs, where each pair consists of the ID of a resource and the fraction of that resource reserved for this worker. """ - if not global_worker.use_raylet: - raise Exception("ray.get_resource_ids() is only supported in the " - "raylet code path.") - if _mode() == LOCAL_MODE: raise Exception( "ray.get_resource_ids() currently does not work in PYTHON " @@ -1232,22 +1137,8 @@ def error_applies_to_driver(error_key, worker=global_worker): def error_info(worker=global_worker): """Return information about failed tasks.""" worker.check_connected() - if worker.use_raylet: - return (global_state.error_messages(job_id=worker.task_driver_id) + - global_state.error_messages(job_id=ray_constants.NIL_JOB_ID)) - error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) - errors = [] - for error_key in error_keys: - if error_applies_to_driver(error_key, worker=worker): - error_contents = worker.redis_client.hgetall(error_key) - error_contents = { - "type": ray.utils.decode(error_contents[b"type"]), - "message": ray.utils.decode(error_contents[b"message"]), - "data": ray.utils.decode(error_contents[b"data"]) - } - errors.append(error_contents) - - return errors + return (global_state.error_messages(job_id=worker.task_driver_id) + + global_state.error_messages(job_id=ray_constants.NIL_JOB_ID)) def _initialize_serialization(driver_id, worker=global_worker): @@ -1343,121 +1234,57 @@ def actor_handle_deserializer(serialized_obj): def get_address_info_from_redis_helper(redis_address, node_ip_address, - use_raylet=False): + redis_password=None): redis_ip_address, redis_port = redis_address.split(":") # For this command to work, some other client (on the same machine as # Redis) must have run "CONFIG SET protected-mode no". redis_client = redis.StrictRedis( - host=redis_ip_address, port=int(redis_port)) - - if not use_raylet: - # The client table prefix must be kept in sync with the file - # "src/common/redis_module/ray_redis_module.cc" where it is defined. - client_keys = redis_client.keys("{}*".format( - ray.gcs_utils.DB_CLIENT_PREFIX)) - # Filter to live clients on the same node and do some basic checking. - plasma_managers = [] - local_schedulers = [] - for key in client_keys: - info = redis_client.hgetall(key) - - # Ignore clients that were deleted. 
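# --- Illustrative usage (editor's example, not part of this diff). From user
# code, ray.get_gpu_ids() is called the same way as before; only the
# raylet-backed lookup above changed. Assumes GPUs are available.
import os
import ray

ray.init(num_gpus=2)

@ray.remote(num_gpus=1)
def which_gpu():
    # CUDA_VISIBLE_DEVICES has already been restricted for this task.
    return ray.get_gpu_ids(), os.environ.get("CUDA_VISIBLE_DEVICES")

print(ray.get(which_gpu.remote()))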
- deleted = info[b"deleted"] - deleted = bool(int(deleted)) - if deleted: - continue - - assert b"ray_client_id" in info - assert b"node_ip_address" in info - assert b"client_type" in info - client_node_ip_address = ray.utils.decode(info[b"node_ip_address"]) - if (client_node_ip_address == node_ip_address or - (client_node_ip_address == "127.0.0.1" - and redis_ip_address == ray.services.get_node_ip_address())): - if ray.utils.decode(info[b"client_type"]) == "plasma_manager": - plasma_managers.append(info) - elif (ray.utils.decode( - info[b"client_type"]) == "local_scheduler"): - local_schedulers.append(info) - # Make sure that we got at least one plasma manager and local - # scheduler. - assert len(plasma_managers) >= 1 - assert len(local_schedulers) >= 1 - # Build the address information. - object_store_addresses = [] - for manager in plasma_managers: - address = ray.utils.decode(manager[b"manager_address"]) - port = services.get_port(address) - object_store_addresses.append( - services.ObjectStoreAddress( - name=ray.utils.decode(manager[b"store_socket_name"]), - manager_name=ray.utils.decode( - manager[b"manager_socket_name"]), - manager_port=port)) - scheduler_names = [ - ray.utils.decode(scheduler[b"local_scheduler_socket_name"]) - for scheduler in local_schedulers - ] - client_info = { - "node_ip_address": node_ip_address, - "redis_address": redis_address, - "object_store_addresses": object_store_addresses, - "local_scheduler_socket_names": scheduler_names, - # Web UI should be running. - "webui_url": _webui_url_helper(redis_client) - } - return client_info - - # Handle the raylet case. - else: - # In the raylet code path, all client data is stored in a zset at the - # key for the nil client. - client_key = b"CLIENT" + NIL_CLIENT_ID - clients = redis_client.zrange(client_key, 0, -1) - raylets = [] - for client_message in clients: - client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData( - client_message, 0) - client_node_ip_address = ray.utils.decode( - client.NodeManagerAddress()) - if (client_node_ip_address == node_ip_address or - (client_node_ip_address == "127.0.0.1" - and redis_ip_address == ray.services.get_node_ip_address())): - raylets.append(client) - # Make sure that at least one raylet has started locally. - # This handles a race condition where Redis has started but - # the raylet has not connected. - if len(raylets) == 0: - raise Exception( - "Redis has started but no raylets have registered yet.") - object_store_addresses = [ - services.ObjectStoreAddress( - name=ray.utils.decode(raylet.ObjectStoreSocketName()), - manager_name=None, - manager_port=None) for raylet in raylets - ] - raylet_socket_names = [ - ray.utils.decode(raylet.RayletSocketName()) for raylet in raylets - ] - return { - "node_ip_address": node_ip_address, - "redis_address": redis_address, - "object_store_addresses": object_store_addresses, - "raylet_socket_names": raylet_socket_names, - # Web UI should be running. - "webui_url": _webui_url_helper(redis_client) - } + host=redis_ip_address, port=int(redis_port), password=redis_password) + + # In the raylet code path, all client data is stored in a zset at the + # key for the nil client. 
+ client_key = b"CLIENT" + NIL_CLIENT_ID + clients = redis_client.zrange(client_key, 0, -1) + raylets = [] + for client_message in clients: + client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData( + client_message, 0) + client_node_ip_address = ray.utils.decode(client.NodeManagerAddress()) + if (client_node_ip_address == node_ip_address or + (client_node_ip_address == "127.0.0.1" + and redis_ip_address == ray.services.get_node_ip_address())): + raylets.append(client) + # Make sure that at least one raylet has started locally. + # This handles a race condition where Redis has started but + # the raylet has not connected. + if len(raylets) == 0: + raise Exception( + "Redis has started but no raylets have registered yet.") + object_store_addresses = [ + ray.utils.decode(raylet.ObjectStoreSocketName()) for raylet in raylets + ] + raylet_socket_names = [ + ray.utils.decode(raylet.RayletSocketName()) for raylet in raylets + ] + return { + "node_ip_address": node_ip_address, + "redis_address": redis_address, + "object_store_addresses": object_store_addresses, + "raylet_socket_names": raylet_socket_names, + # Web UI should be running. + "webui_url": _webui_url_helper(redis_client) + } def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5, - use_raylet=False): + redis_password=None): counter = 0 while True: try: return get_address_info_from_redis_helper( - redis_address, node_ip_address, use_raylet=use_raylet) + redis_address, node_ip_address, redis_password=redis_password) except Exception: if counter == num_retries: raise @@ -1515,6 +1342,8 @@ def _init(address_info=None, num_workers=None, num_local_schedulers=None, object_store_memory=None, + redis_max_memory=None, + collect_profiling_data=True, local_mode=False, driver_mode=None, redirect_worker_output=False, @@ -1525,10 +1354,15 @@ def _init(address_info=None, resources=None, num_redis_shards=None, redis_max_clients=None, + redis_password=None, plasma_directory=None, huge_pages=False, include_webui=True, - use_raylet=None): + driver_id=None, + plasma_store_socket_name=None, + raylet_socket_name=None, + temp_dir=None, + _internal_config=None): """Helper method to connect to an existing Ray cluster or start a new one. This method handles two cases. Either a Ray cluster already exists and we @@ -1550,12 +1384,15 @@ def _init(address_info=None, object IDs. The same value can be used across multiple runs of the same job in order to generate the object IDs in a consistent manner. However, the same ID should not be used for different jobs. - num_workers (int): The number of workers to start. This is only - provided if start_ray_local is True. num_local_schedulers (int): The number of local schedulers to start. This is only provided if start_ray_local is True. object_store_memory: The maximum amount of memory (in bytes) to allow the object store to use. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. This only applies to the + sharded redis tables (task and object tables). + collect_profiling_data: Whether to collect profiling data from workers. local_mode (bool): True if the code should be executed serially without Ray. This is useful for debugging. redirect_worker_output: True if the stdout and stderr of worker @@ -1577,13 +1414,23 @@ def _init(address_info=None, the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this maxclients number. 
+ redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. plasma_directory: A directory where the Plasma memory mapped files will be created. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. include_webui: Boolean flag indicating whether to start the web UI, which is a Jupyter notebook. - use_raylet: True if the new raylet code path should be used. + driver_id: The ID of driver. + plasma_store_socket_name (str): If provided, it will specify the socket + name used by the plasma store. + raylet_socket_name (str): If provided, it will specify the socket path + used by the raylet process. + temp_dir (str): If provided, it will specify the root temporary + directory for the Ray process. + _internal_config (str): JSON configuration for overriding + RayConfig defaults. For testing purposes ONLY. Returns: Address information about the started processes. @@ -1600,10 +1447,10 @@ def _init(address_info=None, else: driver_mode = SCRIPT_MODE - if use_raylet is None and os.environ.get("RAY_USE_XRAY") == "1": - # This environment variable is used in our testing setup. - logger.info("Detected environment variable 'RAY_USE_XRAY'.") - use_raylet = True + if redis_max_memory and collect_profiling_data: + logger.warn("Profiling data cannot be LRU evicted, so it is disabled " + "when redis_max_memory is set.") + collect_profiling_data = False # Get addresses of existing services. if address_info is None: @@ -1626,12 +1473,8 @@ def _init(address_info=None, # Use 1 local scheduler if num_local_schedulers is not provided. If # existing local schedulers are provided, use that count as # num_local_schedulers. - local_schedulers = address_info.get("local_scheduler_socket_names", []) if num_local_schedulers is None: - if len(local_schedulers) > 0: - num_local_schedulers = len(local_schedulers) - else: - num_local_schedulers = 1 + num_local_schedulers = 1 # Use 1 additional redis shard if num_redis_shards is not provided. 
num_redis_shards = 1 if num_redis_shards is None else num_redis_shards @@ -1648,6 +1491,8 @@ def _init(address_info=None, num_workers=num_workers, num_local_schedulers=num_local_schedulers, object_store_memory=object_store_memory, + redis_max_memory=redis_max_memory, + collect_profiling_data=collect_profiling_data, redirect_worker_output=redirect_worker_output, redirect_output=redirect_output, start_workers_from_local_scheduler=( @@ -1655,10 +1500,14 @@ def _init(address_info=None, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, + redis_password=redis_password, plasma_directory=plasma_directory, huge_pages=huge_pages, include_webui=include_webui, - use_raylet=use_raylet) + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir, + _internal_config=_internal_config) else: if redis_address is None: raise Exception("When connecting to an existing cluster, " @@ -1684,18 +1533,34 @@ def _init(address_info=None, if object_store_memory is not None: raise Exception("When connecting to an existing cluster, " "object_store_memory must not be provided.") + if redis_max_memory is not None: + raise Exception("When connecting to an existing cluster, " + "redis_max_memory must not be provided.") if plasma_directory is not None: raise Exception("When connecting to an existing cluster, " "plasma_directory must not be provided.") if huge_pages: raise Exception("When connecting to an existing cluster, " "huge_pages must not be provided.") + if temp_dir is not None: + raise Exception("When connecting to an existing cluster, " + "temp_dir must not be provided.") + if plasma_store_socket_name is not None: + raise Exception("When connecting to an existing cluster, " + "plasma_store_socket_name must not be provided.") + if raylet_socket_name is not None: + raise Exception("When connecting to an existing cluster, " + "raylet_socket_name must not be provided.") + if _internal_config is not None: + raise Exception("When connecting to an existing cluster, " + "_internal_config must not be provided.") + # Get the node IP address if one is not provided. if node_ip_address is None: node_ip_address = services.get_node_ip_address(redis_address) # Get the address info of the processes to connect to from Redis. address_info = get_address_info_from_redis( - redis_address, node_ip_address, use_raylet=use_raylet) + redis_address, node_ip_address, redis_password=redis_password) # Connect this driver to Redis, the object store, and the local scheduler. # Choose the first object store and local scheduler if there are multiple. @@ -1707,24 +1572,22 @@ def _init(address_info=None, driver_address_info = { "node_ip_address": node_ip_address, "redis_address": address_info["redis_address"], - "store_socket_name": ( - address_info["object_store_addresses"][0].name), - "webui_url": address_info["webui_url"] + "store_socket_name": address_info["object_store_addresses"][0], + "webui_url": address_info["webui_url"], } - if not use_raylet: - driver_address_info["manager_socket_name"] = ( - address_info["object_store_addresses"][0].manager_name) - driver_address_info["local_scheduler_socket_name"] = ( - address_info["local_scheduler_socket_names"][0]) - else: - driver_address_info["raylet_socket_name"] = ( - address_info["raylet_socket_names"][0]) + driver_address_info["raylet_socket_name"] = ( + address_info["raylet_socket_names"][0]) + + # We only pass `temp_dir` to a worker (WORKER_MODE). + # It can't be a worker here. 
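
The chain of checks above enforces that cluster-wide options are only accepted when Ray is started locally; when attaching to an existing cluster, only connection-level arguments are allowed. A hedged sketch of the two call patterns (all values and addresses are placeholders):

import ray

# Case 1: start a new single-node cluster. Cluster-wide options are accepted
# because the processes are being launched right here.
ray.init(object_store_memory=10**9,       # illustrative 1 GB store
         temp_dir="/tmp/ray_example")     # illustrative temp root
ray.shutdown()

# Case 2: attach to an already running cluster (address is a placeholder).
# Only connection-level arguments may be passed; adding e.g. temp_dir raises
# "When connecting to an existing cluster, temp_dir must not be provided."
# ray.init(redis_address="192.168.0.10:6379", redis_password=None)
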
connect( driver_address_info, object_id_seed=object_id_seed, mode=driver_mode, worker=global_worker, - use_raylet=use_raylet) + driver_id=driver_id, + redis_password=redis_password, + collect_profiling_data=collect_profiling_data) return address_info @@ -1733,6 +1596,8 @@ def init(redis_address=None, num_gpus=None, resources=None, object_store_memory=None, + redis_max_memory=None, + collect_profiling_data=True, node_ip_address=None, object_id_seed=None, num_workers=None, @@ -1743,14 +1608,19 @@ def init(redis_address=None, ignore_reinit_error=False, num_redis_shards=None, redis_max_clients=None, - redis_protected_mode=True, + redis_password=None, plasma_directory=None, huge_pages=False, include_webui=True, - use_raylet=None, + driver_id=None, configure_logging=True, logging_level=logging.INFO, - logging_format=ray_constants.LOGGER_FORMAT): + logging_format=ray_constants.LOGGER_FORMAT, + plasma_store_socket_name=None, + raylet_socket_name=None, + temp_dir=None, + _internal_config=None, + use_raylet=None): """Connect to an existing Ray cluster or start one and connect to it. This method handles two cases. Either a Ray cluster already exists and we @@ -1784,13 +1654,16 @@ def init(redis_address=None, of that resource available. object_store_memory: The amount of memory (in bytes) to start the object store with. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. This only applies to the + sharded redis tables (task and object tables). + collect_profiling_data: Whether to collect profiling data from workers. node_ip_address (str): The IP address of the node that we are on. object_id_seed (int): Used to seed the deterministic generation of object IDs. The same value can be used across multiple runs of the same job in order to generate the object IDs in a consistent manner. However, the same ID should not be used for different jobs. - num_workers (int): The number of workers to start. This is only - provided if redis_address is not provided. local_mode (bool): True if the code should be executed serially without Ray. This is useful for debugging. redirect_worker_output: True if the stdout and stderr of worker @@ -1803,18 +1676,28 @@ def init(redis_address=None, the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this maxclients number. + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. plasma_directory: A directory where the Plasma memory mapped files will be created. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. include_webui: Boolean flag indicating whether to start the web UI, which is a Jupyter notebook. - use_raylet: True if the new raylet code path should be used. + driver_id: The ID of driver. configure_logging: True if allow the logging cofiguration here. Otherwise, the users may want to configure it by their own. - logging_level: Logging level, default will be loging.INFO. + logging_level: Logging level, default will be logging.INFO. logging_format: Logging format, default will be "%(message)s" which means only contains the message. + plasma_store_socket_name (str): If provided, it will specify the socket + name used by the plasma store. + raylet_socket_name (str): If provided, it will specify the socket path + used by the raylet process. 
+ temp_dir (str): If provided, it will specify the root temporary + directory for the Ray process. + _internal_config (str): JSON configuration for overriding + RayConfig defaults. For testing purposes ONLY. Returns: Address information about the started processes. @@ -1823,9 +1706,25 @@ def init(redis_address=None, Exception: An exception is raised if an inappropriate combination of arguments is passed in. """ + if configure_logging: logging.basicConfig(level=logging_level, format=logging_format) + # Add the use_raylet option for backwards compatibility. + if use_raylet is not None: + if use_raylet: + logger.warn("WARNING: The use_raylet argument has been " + "deprecated. Please remove it.") + else: + raise DeprecationWarning("The use_raylet argument is deprecated. " + "Please remove it.") + + if setproctitle is None: + logger.warning( + "WARNING: Not updating worker name since `setproctitle` is not " + "installed. Install this with `pip install setproctitle` " + "(or ray[debug]) to enable monitoring of worker processes.") + if global_worker.connected: if ignore_reinit_error: logger.error("Calling ray.init() again after it has already been " @@ -1834,11 +1733,6 @@ def init(redis_address=None, else: raise Exception("Perhaps you called ray.init twice by accident?") - if use_raylet is None and os.environ.get("RAY_USE_XRAY") == "1": - # This environment variable is used in our testing setup. - logger.info("Detected environment variable 'RAY_USE_XRAY'.") - use_raylet = True - # Convert hostnames to numerical IP address. if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) @@ -1850,6 +1744,7 @@ def init(redis_address=None, address_info=info, start_ray_local=(redis_address is None), num_workers=num_workers, + object_id_seed=object_id_seed, local_mode=local_mode, driver_mode=driver_mode, redirect_worker_output=redirect_worker_output, @@ -1859,11 +1754,18 @@ def init(redis_address=None, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, + redis_password=redis_password, plasma_directory=plasma_directory, huge_pages=huge_pages, include_webui=include_webui, object_store_memory=object_store_memory, - use_raylet=use_raylet) + redis_max_memory=redis_max_memory, + collect_profiling_data=collect_profiling_data, + driver_id=driver_id, + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir, + _internal_config=_internal_config) for hook in _post_init_hooks: hook() return ret @@ -1899,19 +1801,15 @@ def shutdown(worker=global_worker): worker.plasma_client.disconnect() if worker.mode == SCRIPT_MODE: - # If this is a driver, push the finish time to Redis and clean up any - # other services that were started with the driver. - worker.redis_client.hmset(b"Drivers:" + worker.worker_id, - {"end_time": time.time()}) services.cleanup() else: # If this is not a driver, make sure there are no orphan processes, # besides possibly the worker itself. for process_type, processes in services.all_processes.items(): if process_type == services.PROCESS_TYPE_WORKER: - assert (len(processes)) <= 1 + assert len(processes) <= 1 else: - assert (len(processes) == 0) + assert len(processes) == 0 worker.set_mode(None) @@ -1942,9 +1840,6 @@ def print_error_messages_raylet(worker): This runs in a separate thread on the driver and prints error messages in the background. 
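
As of this change, use_raylet is accepted only for backwards compatibility: passing use_raylet=True logs a deprecation warning, while use_raylet=False raises. A short illustration (argument values are illustrative):

import logging
import ray

# The raylet code path is now the only one, so the flag is at best a no-op.
ray.init(logging_level=logging.WARNING,   # quieter driver logs
         use_raylet=True)                 # deprecated: logs a warning
ray.shutdown()

# use_raylet=False is no longer supported and raises before connecting.
try:
    ray.init(use_raylet=False)
except DeprecationWarning:
    pass
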
""" - if not worker.use_raylet: - raise Exception("This function is specific to the raylet code path.") - worker.error_message_pubsub_client = worker.redis_client.pubsub( ignore_subscribe_messages=True) # Exports that are published after the call to @@ -2014,12 +1909,6 @@ def print_error_messages(worker): worker.error_message_pubsub_client.subscribe("__keyspace@0__:ErrorKeys") num_errors_received = 0 - # Keep a set of all the error messages that we've seen so far in order to - # avoid printing the same error message repeatedly. This is especially - # important when running a script inside of a tool like screen where - # scrolling is difficult. - old_error_messages = set() - # Get the exports that occurred before the call to subscribe. with worker.lock: error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) @@ -2027,11 +1916,7 @@ def print_error_messages(worker): if error_applies_to_driver(error_key, worker=worker): error_message = ray.utils.decode( worker.redis_client.hget(error_key, "message")) - if error_message not in old_error_messages: - logger.error(error_message) - old_error_messages.add(error_message) - else: - logger.error("Suppressing duplicate error message.") + logger.error(error_message) num_errors_received += 1 try: @@ -2042,12 +1927,7 @@ def print_error_messages(worker): if error_applies_to_driver(error_key, worker=worker): error_message = ray.utils.decode( worker.redis_client.hget(error_key, "message")) - if error_message not in old_error_messages: - logger.error(error_message) - old_error_messages.add(error_message) - else: - logger.error( - "Suppressing duplicate error message.") + logger.error(error_message) num_errors_received += 1 except redis.ConnectionError: # When Redis terminates the listen call will throw a ConnectionError, @@ -2059,7 +1939,9 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, - use_raylet=False): + driver_id=None, + redis_password=None, + collect_profiling_data=True): """Connect this worker to the local scheduler, to Plasma, and to Redis. Args: @@ -2069,15 +1951,39 @@ def connect(info, deterministic. mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE. - use_raylet: True if the new raylet code path should be used. + driver_id: The ID of driver. If it's None, then we will generate one. + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. + collect_profiling_data: Whether to collect profiling data from workers. """ # Do some basic checking to make sure we didn't call ray.init twice. error_message = "Perhaps you called ray.init twice by accident?" assert not worker.connected, error_message assert worker.cached_functions_to_run is not None, error_message - assert worker.cached_remote_functions_and_actors is not None, error_message + + # Enable nice stack traces on SIGSEGV etc. + faulthandler.enable(all_threads=False) + + if collect_profiling_data: + worker.profiler = profiling.Profiler(worker) + else: + worker.profiler = profiling.NoopProfiler() + # Initialize some fields. - worker.worker_id = random_string() + if mode is WORKER_MODE: + worker.worker_id = random_string() + if setproctitle: + setproctitle.setproctitle("ray_worker") + else: + # This is the code path of driver mode. 
+ if driver_id is None: + driver_id = ray.ObjectID(random_string()) + + if not isinstance(driver_id, ray.ObjectID): + raise Exception( + "The type of given driver id must be ray.ObjectID.") + + worker.worker_id = driver_id.id() # When tasks are executed on remote workers in the context of multiple # drivers, the task driver ID is used to keep track of which driver is @@ -2091,7 +1997,6 @@ def connect(info, worker.actor_id = NIL_ACTOR_ID worker.connected = True worker.set_mode(mode) - worker.use_raylet = use_raylet # If running Ray in LOCAL_MODE, there is no need to create call # create_worker or to start the worker service. @@ -2104,7 +2009,10 @@ def connect(info, # Create a Redis client. redis_ip_address, redis_port = info["redis_address"].split(":") worker.redis_client = thread_safe_client( - redis.StrictRedis(host=redis_ip_address, port=int(redis_port))) + redis.StrictRedis( + host=redis_ip_address, + port=int(redis_port), + password=redis_password)) # For driver's check that the version information matches the version # information that the Ray cluster was started with. @@ -2117,7 +2025,6 @@ def connect(info, traceback_str = traceback.format_exc() ray.utils.push_error_to_driver_through_redis( worker.redis_client, - worker.use_raylet, ray_constants.VERSION_MISMATCH_PUSH_ERROR, traceback_str, driver_id=None) @@ -2135,16 +2042,19 @@ def connect(info, else: redirect_worker_output = 0 if redirect_worker_output: - log_stdout_file, log_stderr_file = services.new_log_files( - "worker", True) + log_stdout_file, log_stderr_file = ( + tempfile_services.new_worker_redirected_log_file( + worker.worker_id)) sys.stdout = log_stdout_file sys.stderr = log_stderr_file services.record_log_files_in_redis( - info["redis_address"], info["node_ip_address"], - [log_stdout_file, log_stderr_file]) + info["redis_address"], + info["node_ip_address"], [log_stdout_file, log_stderr_file], + password=redis_password) # Create an object for interfacing with the global state. - global_state._initialize_global_state(redis_ip_address, int(redis_port)) + global_state._initialize_global_state( + redis_ip_address, int(redis_port), redis_password=redis_password) # Register the worker with Redis. if mode == SCRIPT_MODE: @@ -2156,14 +2066,13 @@ def connect(info, "driver_id": worker.worker_id, "start_time": time.time(), "plasma_store_socket": info["store_socket_name"], - "plasma_manager_socket": info.get("manager_socket_name"), - "local_scheduler_socket": info.get("local_scheduler_socket_name"), "raylet_socket": info.get("raylet_socket_name") } driver_info["name"] = (main.__file__ if hasattr(main, "__file__") else "INTERACTIVE MODE") worker.redis_client.hmset(b"Drivers:" + worker.worker_id, driver_info) - if not worker.redis_client.exists("webui"): + if (not worker.redis_client.exists("webui") + and info["webui_url"] is not None): worker.redis_client.hmset("webui", {"url": info["webui_url"]}) is_worker = False elif mode == WORKER_MODE: @@ -2171,8 +2080,6 @@ def connect(info, worker_dict = { "node_ip_address": worker.node_ip_address, "plasma_store_socket": info["store_socket_name"], - "plasma_manager_socket": info["manager_socket_name"], - "local_scheduler_socket": info["local_scheduler_socket_name"] } if redirect_worker_output: worker_dict["stdout_file"] = os.path.abspath(log_stdout_file.name) @@ -2183,18 +2090,10 @@ def connect(info, raise Exception("This code should be unreachable.") # Create an object store client. 
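
Everything the driver registers here lands in plain Redis hashes, so it can be inspected with any Redis client as long as the same password is supplied. An illustrative, read-only peek at the web UI URL and the registered drivers, assuming a local cluster with Redis on the default port:

import redis

r = redis.StrictRedis(host="127.0.0.1", port=6379, password=None)

# The web UI URL is stored once under the "webui" hash.
print(r.hget("webui", "url"))

# Each driver writes a "Drivers:<worker_id>" hash with its sockets and name.
for key in r.keys(b"Drivers:*"):
    print(key, r.hgetall(key))
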
- if not worker.use_raylet: - worker.plasma_client = thread_safe_client( - plasma.connect(info["store_socket_name"], - info["manager_socket_name"], 64)) - else: - worker.plasma_client = thread_safe_client( - plasma.connect(info["store_socket_name"], "", 64)) + worker.plasma_client = thread_safe_client( + plasma.connect(info["store_socket_name"], "", 64)) - if not worker.use_raylet: - local_scheduler_socket = info["local_scheduler_socket_name"] - else: - local_scheduler_socket = info["raylet_socket_name"] + raylet_socket = info["raylet_socket_name"] # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. @@ -2225,28 +2124,22 @@ def connect(info, # rerun the driver. nil_actor_counter = 0 - driver_task = ray.local_scheduler.Task( - worker.task_driver_id, ray.ObjectID(NIL_FUNCTION_ID), [], 0, - worker.current_task_id, worker.task_index, - ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), - ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), - nil_actor_counter, False, [], {"CPU": 0}, worker.use_raylet) + driver_task = ray.raylet.Task(worker.task_driver_id, + ray.ObjectID(NIL_FUNCTION_ID), [], 0, + worker.current_task_id, + worker.task_index, + ray.ObjectID(NIL_ACTOR_ID), + ray.ObjectID(NIL_ACTOR_ID), + ray.ObjectID(NIL_ACTOR_ID), + ray.ObjectID(NIL_ACTOR_ID), + nil_actor_counter, [], {"CPU": 0}, {}) # Add the driver task to the task table. - if not worker.use_raylet: - global_state._execute_command( - driver_task.task_id(), "RAY.TASK_TABLE_ADD", - driver_task.task_id().id(), TASK_STATUS_RUNNING, - NIL_LOCAL_SCHEDULER_ID, - driver_task.execution_dependencies_string(), 0, - ray.local_scheduler.task_to_string(driver_task)) - else: - global_state._execute_command( - driver_task.task_id(), "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - ray.gcs_utils.TablePubsub.RAYLET_TASK, - driver_task.task_id().id(), - driver_task._serialized_raylet_task()) + global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD", + ray.gcs_utils.TablePrefix.RAYLET_TASK, + ray.gcs_utils.TablePubsub.RAYLET_TASK, + driver_task.task_id().id(), + driver_task._serialized_raylet_task()) # Set the driver's current task ID to the task ID assigned to the # driver task. @@ -2254,10 +2147,12 @@ def connect(info, else: # A non-driver worker begins without an assigned task. worker.current_task_id = ray.ObjectID(NIL_ID) + # A flag for making sure that we only print one warning message about + # multithreading per worker. + worker.multithreading_warned = False - worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient( - local_scheduler_socket, worker.worker_id, is_worker, - worker.current_task_id, worker.use_raylet) + worker.local_scheduler_client = ray.raylet.LocalSchedulerClient( + raylet_socket, worker.worker_id, is_worker, worker.current_task_id) # Start the import thread import_thread.ImportThread(worker, mode).start() @@ -2269,16 +2164,10 @@ def connect(info, # temporarily using this implementation which constantly queries the # scheduler for new error messages. if mode == SCRIPT_MODE: - if not worker.use_raylet: - t = threading.Thread( - target=print_error_messages, - name="ray_print_error_messages", - args=(worker, )) - else: - t = threading.Thread( - target=print_error_messages_raylet, - name="ray_print_error_messages", - args=(worker, )) + t = threading.Thread( + target=print_error_messages_raylet, + name="ray_print_error_messages", + args=(worker, )) # Making the thread a daemon causes it to exit when the main thread # exits. 
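
The worker now always talks to the plasma store directly; the separate plasma manager socket is gone, hence the empty second argument. For debugging, the same store can be opened outside of Ray with the plasma client bundled in pyarrow at this point in time. The socket path below is an assumption, taken from whatever address_info reports on a given node:

import pyarrow.plasma as plasma

# Same call pattern as worker.connect(): socket path, no manager socket,
# release delay of 64 (both trailing arguments were still required by the
# pyarrow version pinned for this release).
client = plasma.connect("/tmp/plasma_store_socket", "", 64)

# Probe for an arbitrary, illustrative object ID.
print(client.contains(plasma.ObjectID(20 * b"\x00")))
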
t.daemon = True @@ -2286,7 +2175,7 @@ def connect(info, # If we are using the raylet code path and we are not in local mode, start # a background thread to periodically flush profiling data to the GCS. - if mode != LOCAL_MODE and worker.use_raylet: + if mode != LOCAL_MODE: worker.profiler.start_flush_thread() if mode == SCRIPT_MODE: @@ -2312,18 +2201,9 @@ def connect(info, # Export cached functions_to_run. for function in worker.cached_functions_to_run: worker.run_function_on_all_workers(function) - # Export cached remote functions to the workers. - for cached_type, info in worker.cached_remote_functions_and_actors: - if cached_type == "remote_function": - info._export() - elif cached_type == "actor": - (key, actor_class_info) = info - ray.actor.publish_actor_class_to_key(key, actor_class_info, - worker) - else: - assert False, "This code should be unreachable." + # Export cached remote functions and actors to the workers. + worker.function_actor_manager.export_cached() worker.cached_functions_to_run = None - worker.cached_remote_functions_and_actors = None def disconnect(worker=global_worker): @@ -2334,10 +2214,19 @@ def disconnect(worker=global_worker): # tests. worker.connected = False worker.cached_functions_to_run = [] - worker.cached_remote_functions_and_actors = [] + worker.function_actor_manager.reset_cache() worker.serialization_context_map.clear() +@contextmanager +def _changeproctitle(title, next_title): + if setproctitle: + setproctitle.setproctitle(title) + yield + if setproctitle: + setproctitle.setproctitle(next_title) + + def _try_to_compute_deterministic_class_id(cls, depth=5): """Attempt to produce a deterministic class ID for a given class. @@ -2444,7 +2333,7 @@ def register_custom_serializer(cls, # worker. However, determinism is not guaranteed, and the # result may be different on different workers. class_id = _try_to_compute_deterministic_class_id(cls) - except Exception as e: + except Exception: raise serialization.CloudPickleError("Failed to pickle class " "'{}'".format(cls)) else: @@ -2452,6 +2341,9 @@ def register_custom_serializer(cls, # worker and not across workers. class_id = random_string() + # Make sure class_id is a string. + class_id = ray.utils.binary_to_hex(class_id) + if driver_id is None: driver_id_bytes = worker.task_driver_id.id() else: @@ -2538,7 +2430,7 @@ def put(value, worker=global_worker): # In LOCAL_MODE, ray.put is the identity operation. return value object_id = worker.local_scheduler_client.compute_put_id( - worker.current_task_id, worker.put_index, worker.use_raylet) + worker.current_task_id, worker.put_index) worker.put_object(object_id, value) worker.put_index += 1 return object_id @@ -2590,6 +2482,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): type(object_id))) worker.check_connected() + # TODO(swang): Check main thread. with profiling.profile("ray.wait", worker=worker): # When Ray is run in LOCAL_MODE, all functions are run immediately, # so all objects in object_id are ready. @@ -2610,22 +2503,14 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): if num_returns > len(object_ids): raise Exception("num_returns cannot be greater than the number " "of objects provided to ray.wait.") + + # Get the task ID, to notify the backend which task is blocked. 
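
The small _changeproctitle helper defined further down is an ordinary contextmanager: it renames the process while a task runs and sets the follow-up title afterwards, silently doing nothing when setproctitle is not installed. A hedged usage sketch (the titles are illustrative, and importing a private helper like this is for demonstration only):

from ray.worker import _changeproctitle

# While inside the block, `ps`/`top` show the process under the first title;
# afterwards it flips to the second. Requires `pip install setproctitle`
# (or ray[debug]); otherwise both calls are no-ops.
with _changeproctitle("ray_worker:example_task()", "ray_worker"):
    result = sum(range(10))
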
+ with worker.state_lock: + current_task_id = worker.get_current_thread_task_id() + timeout = timeout if timeout is not None else 2**30 - if worker.use_raylet: - ready_ids, remaining_ids = worker.local_scheduler_client.wait( - object_ids, num_returns, timeout, False) - else: - object_id_strs = [ - plasma.ObjectID(object_id.id()) for object_id in object_ids - ] - ready_ids, remaining_ids = worker.plasma_client.wait( - object_id_strs, timeout, num_returns) - ready_ids = [ - ray.ObjectID(object_id.binary()) for object_id in ready_ids - ] - remaining_ids = [ - ray.ObjectID(object_id.binary()) for object_id in remaining_ids - ] + ready_ids, remaining_ids = worker.local_scheduler_client.wait( + object_ids, num_returns, timeout, False, current_task_id) return ready_ids, remaining_ids diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index cd7b3f4a45c35..dc1085783b8aa 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -9,6 +9,7 @@ import ray import ray.actor import ray.ray_constants as ray_constants +import ray.tempfile_services as tempfile_services parser = argparse.ArgumentParser( description=("Parse addresses for the worker " @@ -23,6 +24,12 @@ required=True, type=str, help="the address to use for Redis") +parser.add_argument( + "--redis-password", + required=False, + type=str, + default=None, + help="the password to use for Redis") parser.add_argument( "--object-store-name", required=True, @@ -33,11 +40,6 @@ required=False, type=str, help="the object store manager's name") -parser.add_argument( - "--local-scheduler-name", - required=False, - type=str, - help="the local scheduler's name") parser.add_argument( "--raylet-name", required=False, type=str, help="the raylet's name") parser.add_argument( @@ -53,6 +55,17 @@ type=str, default=ray_constants.LOGGER_FORMAT, help=ray_constants.LOGGER_FORMAT_HELP) +parser.add_argument( + "--collect-profiling-data", + type=int, # int since argparse can't handle bool values + default=1, + help="Whether to collect profiling data from workers.") +parser.add_argument( + "--temp-dir", + required=False, + type=str, + default=None, + help="Specify the path of the temporary directory use by Ray process.") if __name__ == "__main__": args = parser.parse_args() @@ -60,18 +73,24 @@ info = { "node_ip_address": args.node_ip_address, "redis_address": args.redis_address, + "redis_password": args.redis_password, "store_socket_name": args.object_store_name, "manager_socket_name": args.object_store_manager_name, - "local_scheduler_socket_name": args.local_scheduler_name, - "raylet_socket_name": args.raylet_name + "raylet_socket_name": args.raylet_name, } logging.basicConfig( level=logging.getLevelName(args.logging_level.upper()), format=args.logging_format) + # Override the temporary directory. + tempfile_services.set_temp_root(args.temp_dir) + ray.worker.connect( - info, mode=ray.WORKER_MODE, use_raylet=(args.raylet_name is not None)) + info, + mode=ray.WORKER_MODE, + redis_password=args.redis_password, + collect_profiling_data=args.collect_profiling_data) error_explanation = """ This error is unexpected and should not have happened. Somehow a worker @@ -86,7 +105,7 @@ # main_loop. If an exception is thrown here, then that means that # there is some error that we didn't anticipate. 
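
ray.wait now always goes through the raylet client and passes the calling task's ID so the backend knows which task is blocked; the timeout is expressed in milliseconds at this point in Ray's history, with 2**30 standing in for "no timeout". A small usage sketch:

import ray

ray.init()

@ray.remote
def f(i):
    return i

object_ids = [f.remote(i) for i in range(4)]

# Wait until at least two of the four results are ready, or 1000 ms elapse;
# num_returns may not exceed len(object_ids).
ready, remaining = ray.wait(object_ids, num_returns=2, timeout=1000)
print(len(ready), len(remaining))
ray.shutdown()
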
ray.worker.global_worker.main_loop() - except Exception as e: + except Exception: traceback_str = traceback.format_exc() + error_explanation ray.utils.push_error_to_driver( ray.worker.global_worker, diff --git a/python/setup.py b/python/setup.py index 70d7cd87fadb7..c92ffa65b481d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -3,6 +3,7 @@ from __future__ import print_function import os +import re import shutil import subprocess import sys @@ -19,13 +20,10 @@ # NOTE: The lists below must be kept in sync with ray/CMakeLists.txt. ray_files = [ - "ray/core/src/common/thirdparty/redis/src/redis-server", - "ray/core/src/common/redis_module/libray_redis_module.so", + "ray/core/src/ray/thirdparty/redis/src/redis-server", + "ray/core/src/ray/gcs/redis_module/libray_redis_module.so", "ray/core/src/plasma/plasma_store_server", - "ray/core/src/plasma/plasma_manager", - "ray/core/src/local_scheduler/local_scheduler", - "ray/core/src/local_scheduler/liblocal_scheduler_library_python.so", - "ray/core/src/global_scheduler/global_scheduler", + "ray/core/src/ray/raylet/liblocal_scheduler_library_python.so", "ray/core/src/ray/raylet/raylet_monitor", "ray/core/src/ray/raylet/raylet", "ray/WebUI.ipynb" ] @@ -47,6 +45,7 @@ ray_autoscaler_files = [ "ray/autoscaler/aws/example-full.yaml", "ray/autoscaler/gcp/example-full.yaml", + "ray/autoscaler/local/example-full.yaml", ] if "RAY_USE_NEW_GCS" in os.environ and os.environ["RAY_USE_NEW_GCS"] == "on": @@ -65,7 +64,10 @@ optional_ray_files += ray_autoscaler_files -extras = {"rllib": ["pyyaml", "gym[atari]", "opencv-python", "lz4", "scipy"]} +extras = { + "rllib": ["pyyaml", "gym[atari]", "opencv-python", "lz4", "scipy"], + "debug": ["psutil", "setproctitle", "py-spy"], +} class build_ext(_build_ext.build_ext): @@ -98,7 +100,7 @@ def run(self): for filename in optional_ray_files: try: self.move_file(filename) - except Exception as e: + except Exception: print("Failed to copy optional file {}. This is ok." .format(filename)) @@ -121,26 +123,47 @@ def has_ext_modules(self): return True +def find_version(*filepath): + # Extract version information from filepath + here = os.path.abspath(os.path.dirname(__file__)) + with open(os.path.join(here, *filepath)) as fp: + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", + fp.read(), re.M) + if version_match: + return version_match.group(1) + raise RuntimeError("Unable to find version string.") + + +requires = [ + "numpy", + "funcsigs", + "click", + "colorama", + "pytest", + "pyyaml", + "redis", + # The six module is required by pyarrow. + "six >= 1.0.0", + "flatbuffers", +] + +if sys.version_info < (3, 0): + requires.append("faulthandler") + setup( name="ray", - # The version string is also in __init__.py. TODO(pcm): Fix this. - version="0.5.3", + version=find_version("ray", "__init__.py"), + description=("A system for parallel and distributed Python that unifies " + "the ML ecosystem."), + long_description=open("../README.rst").read(), + url="https://github.com/ray-project/ray", + keywords=("ray distributed parallel machine-learning " + "reinforcement-learning deep-learning python"), packages=find_packages(), cmdclass={"build_ext": build_ext}, # The BinaryDistribution argument triggers build_ext. distclass=BinaryDistribution, - install_requires=[ - "numpy", - "funcsigs", - "click", - "colorama", - "pytest", - "pyyaml", - "redis", - # The six module is required by pyarrow. 
- "six >= 1.0.0", - "flatbuffers" - ], + install_requires=requires, setup_requires=["cython >= 0.27, < 0.28"], extras_require=extras, entry_points={ diff --git a/site/Gemfile b/site/Gemfile index 8af267397b31b..9ae4bf67ff67d 100644 --- a/site/Gemfile +++ b/site/Gemfile @@ -9,7 +9,7 @@ ruby RUBY_VERSION # # This will help ensure the proper Jekyll version is running. # Happy Jekylling! -gem "jekyll", "3.4.3" +gem "jekyll", ">= 3.6.3" # This is the default theme for new Jekyll sites. You may change this to anything you like. gem "minima", "~> 2.0" diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt deleted file mode 100644 index b024b4a0419f3..0000000000000 --- a/src/common/CMakeLists.txt +++ /dev/null @@ -1,131 +0,0 @@ -cmake_minimum_required(VERSION 3.4) - -project(common) - -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - include_directories("${CMAKE_CURRENT_LIST_DIR}/lib/python") -endif () - -add_subdirectory(redis_module) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -g") - -include_directories(thirdparty/ae) - -# Compile flatbuffers - -set(COMMON_FBS_SRC "${CMAKE_CURRENT_LIST_DIR}/format/common.fbs") -set(OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/format/) - -set(COMMON_FBS_OUTPUT_FILES - "${OUTPUT_DIR}/common_generated.h") - -add_custom_target(gen_common_fbs DEPENDS ${COMMON_FBS_OUTPUT_FILES}) - -add_custom_command( - OUTPUT ${COMMON_FBS_OUTPUT_FILES} - # The --gen-object-api flag generates a C++ class MessageT for each - # flatbuffers message Message, which can be used to store deserialized - # messages in data structures. This is currently used for ObjectInfo for - # example. - COMMAND ${FLATBUFFERS_COMPILER} -c -o ${OUTPUT_DIR} ${COMMON_FBS_SRC} --gen-object-api --scoped-enums - DEPENDS ${FBS_DEPENDS} - COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}" - VERBATIM) - -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - add_custom_target(gen_common_python_fbs DEPENDS ${COMMON_FBS_OUTPUT_FILES}) - - # Generate Python bindings for the flatbuffers objects. - set(PYTHON_OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/../../python/ray/core/generated/) - add_custom_command( - TARGET gen_common_python_fbs - COMMAND ${FLATBUFFERS_COMPILER} -p -o ${PYTHON_OUTPUT_DIR} ${COMMON_FBS_SRC} - DEPENDS ${FBS_DEPENDS} - COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}" - VERBATIM) - - # Encode the fact that the ray redis module requires the autogenerated - # flatbuffer files to compile. - add_dependencies(ray_redis_module gen_common_python_fbs) - - add_dependencies(gen_common_python_fbs flatbuffers_ep) -endif() - -if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") - add_custom_target(gen_common_java_fbs DEPENDS ${COMMON_FBS_OUTPUT_FILES}) - - # Generate Java bindings for the flatbuffers objects. - set(JAVA_OUTPUT_DIR ${CMAKE_BINARY_DIR}/generated/java) - add_custom_command( - TARGET gen_common_java_fbs - COMMAND ${FLATBUFFERS_COMPILER} -j -o ${JAVA_OUTPUT_DIR} ${COMMON_FBS_SRC} - DEPENDS ${FBS_DEPENDS} - COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}" - VERBATIM) - - # Encode the fact that the ray redis module requires the autogenerated - # flatbuffer files to compile. 
- add_dependencies(ray_redis_module gen_common_java_fbs) - - add_dependencies(gen_common_java_fbs flatbuffers_ep) -endif() - -add_custom_target( - hiredis - COMMAND make - WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/thirdparty/hiredis) - -add_library(common STATIC - event_loop.cc - common.cc - common_protocol.cc - task.cc - io.cc - net.cc - logging.cc - state/redis.cc - state/table.cc - state/object_table.cc - state/task_table.cc - state/db_client_table.cc - state/driver_table.cc - state/actor_notification_table.cc - state/local_scheduler_table.cc - state/error_table.cc - thirdparty/ae/ae.c - thirdparty/sha256.c) - -add_dependencies(common arrow) - -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - add_dependencies(common gen_common_python_fbs) -endif() - -if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") - add_dependencies(common gen_common_java_fbs) -endif() - -target_link_libraries(common "${CMAKE_CURRENT_LIST_DIR}/thirdparty/hiredis/libhiredis.a") - -function(define_test test_name library) - add_executable(${test_name} test/${test_name}.cc ${ARGN}) - add_dependencies(${test_name} hiredis flatbuffers_ep) - target_link_libraries(${test_name} common ${FLATBUFFERS_STATIC_LIB} ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${library} -lpthread) - target_compile_options(${test_name} PUBLIC "-DPLASMA_TEST -DLOCAL_SCHEDULER_TEST -DCOMMON_TEST -DRAY_COMMON_LOG_LEVEL=4") -endfunction() - -define_test(db_tests "") -define_test(io_tests "") -define_test(task_tests "") -define_test(redis_tests "") -define_test(task_table_tests "") -define_test(object_table_tests "") - -add_custom_target(copy_redis ALL) -foreach(file "redis-cli" "redis-server") -add_custom_command(TARGET copy_redis POST_BUILD - COMMAND ${CMAKE_COMMAND} -E - copy ${CMAKE_CURRENT_LIST_DIR}/../../thirdparty/pkg/redis/src/${file} - ${CMAKE_BINARY_DIR}/src/common/thirdparty/redis/src/${file}) -endforeach() diff --git a/src/common/common.cc b/src/common/common.cc deleted file mode 100644 index 0a6da6a2936e8..0000000000000 --- a/src/common/common.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include "common.h" - -#include -#include -#include -#include -#include -#include - -#include "io.h" -#include - -const unsigned char NIL_DIGEST[DIGEST_SIZE] = {0}; - -int64_t current_time_ms() { - std::chrono::milliseconds ms_since_epoch = - std::chrono::duration_cast( - std::chrono::steady_clock::now().time_since_epoch()); - return ms_since_epoch.count(); -} diff --git a/src/common/common.h b/src/common/common.h deleted file mode 100644 index f95bfcca5d262..0000000000000 --- a/src/common/common.h +++ /dev/null @@ -1,75 +0,0 @@ -#ifndef COMMON_H -#define COMMON_H - -#include -#include -#include -#ifndef __STDC_FORMAT_MACROS -#define __STDC_FORMAT_MACROS -#endif -#include -#include -#ifndef _WIN32 -#include -#endif - -#ifdef __cplusplus -#include -extern "C" { -#endif -#include "sha256.h" -#ifdef __cplusplus -} -#endif - -#include "arrow/util/macros.h" -#include "plasma/common.h" -#include "ray/id.h" -#include "ray/util/logging.h" - -#include "state/ray_config.h" - -/** Definitions for Ray logging levels. */ -#define RAY_COMMON_DEBUG 0 -#define RAY_COMMON_INFO 1 -#define RAY_COMMON_WARNING 2 -#define RAY_COMMON_ERROR 3 -#define RAY_COMMON_FATAL 4 - -/** - * RAY_COMMON_LOG_LEVEL should be defined to one of the above logging level - * integer values. Any logging statement in the code with a logging level - * greater than or equal to RAY_COMMON_LOG_LEVEL will be outputted to stderr. - * The default logging level is INFO. 
*/ -#ifndef RAY_COMMON_LOG_LEVEL -#define RAY_COMMON_LOG_LEVEL RAY_COMMON_INFO -#endif - -/* These are exit codes for common errors that can occur in Ray components. */ -#define EXIT_COULD_NOT_BIND_PORT -2 - -/** This macro indicates that this pointer owns the data it is pointing to - * and is responsible for freeing it. */ -#define OWNER - -/** The worker ID is the ID of a worker or driver. */ -typedef ray::UniqueID WorkerID; - -typedef ray::UniqueID DBClientID; - -#define MAX(x, y) ((x) >= (y) ? (x) : (y)) -#define MIN(x, y) ((x) <= (y) ? (x) : (y)) - -/** Definitions for computing hash digests. */ -#define DIGEST_SIZE SHA256_BLOCK_SIZE - -extern const unsigned char NIL_DIGEST[DIGEST_SIZE]; - -/** - * Return the current time in milliseconds since the Unix epoch. - * - * @return The number of milliseconds since the Unix epoch. - */ -int64_t current_time_ms(); - -#endif diff --git a/src/common/doc/tasks.md b/src/common/doc/tasks.md deleted file mode 100644 index 4431afae2ee9e..0000000000000 --- a/src/common/doc/tasks.md +++ /dev/null @@ -1,32 +0,0 @@ -# Task specifications, task instances and task logs - -A *task specification* contains all information that is needed for computing -the results of a task: - -- The ID of the task -- The function ID of the function that executes the task -- The arguments (either object IDs for pass by reference -or values for pass by value) -- The IDs of the result objects - -From these, a task ID can be computed which is also stored in the task -specification. - -A *task* represents the execution of a task specification. -It consists of: - -- A scheduling state (WAITING, SCHEDULED, RUNNING, DONE) -- The target node where the task is scheduled or executed -- The task specification - -The task data structures are defined in `common/task.h`. - -The *task table* is a mapping from the task ID to the *task* information. It is -updated by various parts of the system: - -1. The local scheduler writes it with status WAITING when submits a task to the global scheduler -2. The global scheduler appends an update WAITING -> SCHEDULED together with the node ID when assigning the task to a local scheduler -3. The local scheduler appends an update SCHEDULED -> RUNNING when it assigns a task to a worker -4. The local scheduler appends an update RUNNING -> DONE when the task finishes execution - -The task table is defined in `common/state/task_table.h`. diff --git a/src/common/event_loop.cc b/src/common/event_loop.cc deleted file mode 100644 index e3d9cc4a2dc6f..0000000000000 --- a/src/common/event_loop.cc +++ /dev/null @@ -1,63 +0,0 @@ -#include "event_loop.h" - -#include "common.h" -#include - -#define INITIAL_EVENT_LOOP_SIZE 1024 - -event_loop *event_loop_create(void) { - return aeCreateEventLoop(INITIAL_EVENT_LOOP_SIZE); -} - -void event_loop_destroy(event_loop *loop) { - /* Clean up timer events. This is to make valgrind happy. */ - aeTimeEvent *te = loop->timeEventHead; - while (te) { - aeTimeEvent *next = te->next; - free(te); - te = next; - } - aeDeleteEventLoop(loop); -} - -bool event_loop_add_file(event_loop *loop, - int fd, - int events, - event_loop_file_handler handler, - void *context) { - /* Try to add the file descriptor. */ - int err = aeCreateFileEvent(loop, fd, events, handler, context); - /* If it cannot be added, increase the size of the event loop. 
*/ - if (err == AE_ERR && errno == ERANGE) { - err = aeResizeSetSize(loop, 3 * aeGetSetSize(loop) / 2); - if (err != AE_OK) { - return false; - } - err = aeCreateFileEvent(loop, fd, events, handler, context); - } - /* In any case, test if there were errors. */ - return (err == AE_OK); -} - -void event_loop_remove_file(event_loop *loop, int fd) { - aeDeleteFileEvent(loop, fd, EVENT_LOOP_READ | EVENT_LOOP_WRITE); -} - -int64_t event_loop_add_timer(event_loop *loop, - int64_t timeout, - event_loop_timer_handler handler, - void *context) { - return aeCreateTimeEvent(loop, timeout, handler, context, NULL); -} - -int event_loop_remove_timer(event_loop *loop, int64_t id) { - return aeDeleteTimeEvent(loop, id); -} - -void event_loop_run(event_loop *loop) { - aeMain(loop); -} - -void event_loop_stop(event_loop *loop) { - aeStop(loop); -} diff --git a/src/common/event_loop.h b/src/common/event_loop.h deleted file mode 100644 index e489ab4fb6729..0000000000000 --- a/src/common/event_loop.h +++ /dev/null @@ -1,103 +0,0 @@ -#ifndef EVENT_LOOP_H -#define EVENT_LOOP_H - -#include - -extern "C" { -#ifdef _WIN32 -/* Quirks mean that Windows version needs to be included differently */ -#include -#include -#else -#include "ae/ae.h" -#endif -} - -/* Unique timer ID that will be generated when the timer is added to the - * event loop. Will not be reused later on in another call - * to event_loop_add_timer. */ -typedef long long timer_id; - -typedef aeEventLoop event_loop; - -/* File descriptor is readable. */ -#define EVENT_LOOP_READ AE_READABLE - -/* File descriptor is writable. */ -#define EVENT_LOOP_WRITE AE_WRITABLE - -/* Constant specifying that the timer is done and it will be removed. */ -#define EVENT_LOOP_TIMER_DONE AE_NOMORE - -/* Signature of the handler that will be called when there is a new event - * on the file descriptor that this handler has been registered for. The - * context is the one that was passed into add_file by the user. The - * events parameter indicates which event is available on the file, - * it can be EVENT_LOOP_READ or EVENT_LOOP_WRITE. */ -typedef void (*event_loop_file_handler)(event_loop *loop, - int fd, - void *context, - int events); - -/* This handler will be called when a timer times out. The id of the timer - * as well as the context that was specified when registering this handler - * are passed as arguments. The return is the number of milliseconds the - * timer shall be reset to or EVENT_LOOP_TIMER_DONE if the timer shall - * not be triggered again. */ -typedef int (*event_loop_timer_handler)(event_loop *loop, - timer_id timer_id, - void *context); - -/* Create and return a new event loop. */ -event_loop *event_loop_create(void); - -/* Deallocate space associated with the event loop that was created - * with the "create" function. */ -void event_loop_destroy(event_loop *loop); - -/* Register a handler that will be called any time a new event happens on - * a file descriptor. Can specify a context that will be passed as an - * argument to the handler. Currently there can only be one handler per file. - * The events parameter specifies which events we listen to: EVENT_LOOP_READ - * or EVENT_LOOP_WRITE. */ -bool event_loop_add_file(event_loop *loop, - int fd, - int events, - event_loop_file_handler handler, - void *context); - -/* Remove a registered file event handler from the event loop. */ -void event_loop_remove_file(event_loop *loop, int fd); - -/** Register a handler that will be called after a time slice of - * "timeout" milliseconds. 
- * - * @param loop The event loop. - * @param timeout The timeout in milliseconds. - * @param handler The handler for the timeout. - * @param context User context that can be passed in and will be passed in - * as an argument for the timer handler. - * @return The ID of the timer. - */ -int64_t event_loop_add_timer(event_loop *loop, - int64_t timeout, - event_loop_timer_handler handler, - void *context); - -/** - * Remove a registered time event handler from the event loop. Can be called - * multiple times on the same timer. - * - * @param loop The event loop. - * @param timer_id The ID of the timer to be removed. - * @return Returns 0 if the removal was successful. - */ -int event_loop_remove_timer(event_loop *loop, int64_t timer_id); - -/* Run the event loop. */ -void event_loop_run(event_loop *loop); - -/* Stop the event loop. */ -void event_loop_stop(event_loop *loop); - -#endif diff --git a/src/common/format/common.fbs b/src/common/format/common.fbs deleted file mode 100644 index 9dc9f651a3e32..0000000000000 --- a/src/common/format/common.fbs +++ /dev/null @@ -1,200 +0,0 @@ - -// Indices into resource vectors. -// A resource vector maps a resource index to the number -// of units of that resource required. - -table Arg { - // Object ID for pass-by-reference arguments. Normally there is only one - // object ID in this list which represents the object that is being passed. - // However to support reducers in a MapReduce workload, we also support - // passing multiple object IDs for each argument. - object_ids: [string]; - // Data for pass-by-value arguments. - data: string; -} - -table ResourcePair { - // The name of the resource. - key: string; - // The quantity of the resource. - value: double; -} - -// NOTE: This enum is duplicate with the `Language` enum in `gcs.fbs`, -// because we cannot include this file in `gcs.fbs` due to cyclic dependency. -// TODO(raulchen): remove it once we get rid of legacy ray. -enum TaskLanguage:int { - PYTHON = 0, - JAVA = 1 -} - -table TaskInfo { - // ID of the driver that created this task. - driver_id: string; - // Task ID of the task. - task_id: string; - // Task ID of the parent task. - parent_task_id: string; - // A count of the number of tasks submitted by the parent task before this one. - parent_counter: int; - // The ID of the actor to create if this is an actor creation task. - actor_creation_id: string; - // The dummy object ID of the actor creation task if this is an actor method. - actor_creation_dummy_object_id: string; - // Actor ID of the task. This is the actor that this task is executed on - // or NIL_ACTOR_ID if the task is just a normal task. - actor_id: string; - // The ID of the handle that was used to submit the task. This should be - // unique across handles with the same actor_id. - actor_handle_id: string; - // Number of tasks that have been submitted to this actor so far. - actor_counter: int; - // True if this task is an actor checkpoint task and false otherwise. - is_actor_checkpoint_method: bool; - // Function ID of the task. - function_id: string; - // Task arguments. - args: [Arg]; - // Object IDs of return values. - returns: [string]; - // The required_resources vector indicates the quantities of the different - // resources required by this task. - required_resources: [ResourcePair]; - // The language that this task belongs to - language: TaskLanguage; - // Function descriptor, which is a list of strings that can - // uniquely describe a function. 
- // For a Python function, it should be: [module_name, class_name, function_name] - // For a Java function, it should be: [class_name, method_name, type_descriptor] - // TODO(hchen): after changing Python worker to use function_descriptor, - // function_id can be removed. - function_descriptor: [string]; -} - -// Object information data structure. -// NOTE(pcm): This structure is replicated in -// https://github.com/apache/arrow/blob/master/cpp/src/plasma/format/common.fbs, -// so if you modify it, you should also modify that one. -table ObjectInfo { - // Object ID of this object. - object_id: string; - // Number of bytes the content of this object occupies in memory. - data_size: long; - // Number of bytes the metadata of this object occupies in memory. - metadata_size: long; - // Number of clients using the objects. - ref_count: int; - // Unix epoch of when this object was created. - create_time: long; - // How long creation of this object took. - construct_duration: long; - // Hash of the object content. If the object is not sealed yet this is - // an empty string. - digest: string; - // Specifies if this object was deleted or added. - is_deletion: bool; -} - -root_type TaskInfo; - -table TaskExecutionDependencies { - // A list of object IDs representing this task's dependencies at execution - // time. - execution_dependencies: [string]; -} - -root_type TaskExecutionDependencies; - -table SubscribeToNotificationsReply { - // The object ID of the object that the notification is about. - object_id: string; - // The size of the object. - object_size: long; - // The IDs of the managers that contain this object. - manager_ids: [string]; -} - -root_type SubscribeToNotificationsReply; - -table TaskReply { - // The task ID of the task that the message is about. - task_id: string; - // The state of the task. This is encoded as a bit mask of scheduling_state - // enum values in task.h. - state: long; - // A local scheduler ID. - local_scheduler_id: string; - // A string of bytes representing the task's TaskExecutionDependencies. - execution_dependencies: string; - // A string of bytes representing the task specification. - task_spec: string; - // The number of times the task was spilled back by local schedulers. - spillback_count: long; - // A boolean representing whether the update was successful. This field - // should only be used for test-and-set operations. - updated: bool; -} - -root_type TaskReply; - -table SubscribeToDBClientTableReply { - // The db client ID of the client that the message is about. - db_client_id: string; - // The type of the client. - client_type: string; - // If the client is a local scheduler, this is the address of the plasma - // manager that the local scheduler is connected to. Otherwise, it is empty. - manager_address: string; - // True if the message is about the addition of a client and false if it is - // about the deletion of a client. - is_insertion: bool; -} - -root_type SubscribeToDBClientTableReply; - -table LocalSchedulerInfoMessage { - // The db client ID of the client that the message is about. - db_client_id: string; - // The total number of workers that are connected to this local scheduler. - total_num_workers: long; - // The number of tasks queued in this local scheduler. - task_queue_length: long; - // The number of workers that are available and waiting for tasks. - available_workers: long; - // The resources generally available to this local scheduler. 
- static_resources: [ResourcePair]; - // The resources currently available to this local scheduler. - dynamic_resources: [ResourcePair]; - // Whether the local scheduler is dead. If true, then all other fields - // besides `db_client_id` will not be set. - is_dead: bool; -} - -root_type LocalSchedulerInfoMessage; - -table ResultTableReply { - // The task ID of the task that created the object. - task_id: string; - // Whether the task created the object through a ray.put. - is_put: bool; - // The size of the object created. - data_size: long; - // The hash of the object created. - hash: string; -} - -root_type ResultTableReply; - -table DriverTableMessage { - // The driver ID of the driver that died. - driver_id: string; -} - -table ActorCreationNotification { - // The ID of the actor that was created. - actor_id: string; - // The ID of the driver that created the actor. - driver_id: string; - // The ID of the local scheduler that created the actor. - local_scheduler_id: string; -} diff --git a/src/common/io.cc b/src/common/io.cc deleted file mode 100644 index 1999b70546694..0000000000000 --- a/src/common/io.cc +++ /dev/null @@ -1,416 +0,0 @@ -#include "io.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common.h" -#include "event_loop.h" - -#ifndef _WIN32 -/* This function is actually not declared in standard POSIX, so declare it. */ -extern int usleep(useconds_t usec); -#endif - -int bind_inet_sock(const int port, bool shall_listen) { - struct sockaddr_in name; - int socket_fd = socket(PF_INET, SOCK_STREAM, 0); - if (socket_fd < 0) { - RAY_LOG(ERROR) << "socket() failed for port " << port; - return -1; - } - name.sin_family = AF_INET; - name.sin_port = htons(port); - name.sin_addr.s_addr = htonl(INADDR_ANY); - int on = 1; - /* TODO(pcm): http://stackoverflow.com/q/1150635 */ - if (ioctl(socket_fd, FIONBIO, (char *) &on) < 0) { - RAY_LOG(ERROR) << "ioctl failed"; - close(socket_fd); - return -1; - } - int *const pon = (int *const) & on; - if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, pon, sizeof(on)) < 0) { - RAY_LOG(ERROR) << "setsockopt failed for port " << port; - close(socket_fd); - return -1; - } - if (bind(socket_fd, (struct sockaddr *) &name, sizeof(name)) < 0) { - RAY_LOG(ERROR) << "Bind failed for port " << port; - close(socket_fd); - return -1; - } - if (shall_listen && listen(socket_fd, 128) == -1) { - RAY_LOG(ERROR) << "Could not listen to socket " << port; - close(socket_fd); - return -1; - } - return socket_fd; -} - -int bind_ipc_sock(const char *socket_pathname, bool shall_listen) { - struct sockaddr_un socket_address; - int socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (socket_fd < 0) { - RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname; - return -1; - } - /* Tell the system to allow the port to be reused. 
*/ - int on = 1; - if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, (char *) &on, - sizeof(on)) < 0) { - RAY_LOG(ERROR) << "setsockopt failed for pathname " << socket_pathname; - close(socket_fd); - return -1; - } - - unlink(socket_pathname); - memset(&socket_address, 0, sizeof(socket_address)); - socket_address.sun_family = AF_UNIX; - if (strlen(socket_pathname) + 1 > sizeof(socket_address.sun_path)) { - RAY_LOG(ERROR) << "Socket pathname is too long."; - close(socket_fd); - return -1; - } - strncpy(socket_address.sun_path, socket_pathname, - strlen(socket_pathname) + 1); - - if (bind(socket_fd, (struct sockaddr *) &socket_address, - sizeof(socket_address)) != 0) { - RAY_LOG(ERROR) << "Bind failed for pathname " << socket_pathname; - close(socket_fd); - return -1; - } - if (shall_listen && listen(socket_fd, 128) == -1) { - RAY_LOG(ERROR) << "Could not listen to socket " << socket_pathname; - close(socket_fd); - return -1; - } - return socket_fd; -} - -int connect_ipc_sock_retry(const char *socket_pathname, - int num_retries, - int64_t timeout) { - /* Pick the default values if the user did not specify. */ - if (num_retries < 0) { - num_retries = RayConfig::instance().num_connect_attempts(); - } - if (timeout < 0) { - timeout = RayConfig::instance().connect_timeout_milliseconds(); - } - - RAY_CHECK(socket_pathname); - int fd = -1; - for (int num_attempts = 0; num_attempts < num_retries; ++num_attempts) { - fd = connect_ipc_sock(socket_pathname); - if (fd >= 0) { - break; - } - if (num_attempts == 0) { - RAY_LOG(ERROR) << "Connection to socket failed for pathname " - << socket_pathname; - } - /* Sleep for timeout milliseconds. */ - usleep(timeout * 1000); - } - /* If we could not connect to the socket, exit. */ - if (fd == -1) { - RAY_LOG(FATAL) << "Could not connect to socket " << socket_pathname; - } - return fd; -} - -int connect_ipc_sock(const char *socket_pathname) { - struct sockaddr_un socket_address; - int socket_fd; - - socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (socket_fd < 0) { - RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname; - return -1; - } - - memset(&socket_address, 0, sizeof(socket_address)); - socket_address.sun_family = AF_UNIX; - if (strlen(socket_pathname) + 1 > sizeof(socket_address.sun_path)) { - RAY_LOG(ERROR) << "Socket pathname is too long."; - return -1; - } - strncpy(socket_address.sun_path, socket_pathname, - strlen(socket_pathname) + 1); - - if (connect(socket_fd, (struct sockaddr *) &socket_address, - sizeof(socket_address)) != 0) { - close(socket_fd); - return -1; - } - - return socket_fd; -} - -int connect_inet_sock_retry(const char *ip_addr, - int port, - int num_retries, - int64_t timeout) { - /* Pick the default values if the user did not specify. */ - if (num_retries < 0) { - num_retries = RayConfig::instance().num_connect_attempts(); - } - if (timeout < 0) { - timeout = RayConfig::instance().connect_timeout_milliseconds(); - } - - RAY_CHECK(ip_addr); - int fd = -1; - for (int num_attempts = 0; num_attempts < num_retries; ++num_attempts) { - fd = connect_inet_sock(ip_addr, port); - if (fd >= 0) { - break; - } - if (num_attempts == 0) { - RAY_LOG(ERROR) << "Connection to socket failed for address " << ip_addr - << ":" << port; - } - /* Sleep for timeout milliseconds. */ - usleep(timeout * 1000); - } - /* If we could not connect to the socket, exit. 
*/ - if (fd == -1) { - RAY_LOG(FATAL) << "Could not connect to address " << ip_addr << ":" << port; - } - return fd; -} - -int connect_inet_sock(const char *ip_addr, int port) { - int fd = socket(PF_INET, SOCK_STREAM, 0); - if (fd < 0) { - RAY_LOG(ERROR) << "socket() failed for address " << ip_addr << ":" << port; - return -1; - } - - struct hostent *manager = gethostbyname(ip_addr); /* TODO(pcm): cache this */ - if (!manager) { - RAY_LOG(ERROR) << "Failed to get hostname from address " << ip_addr << ":" - << port; - close(fd); - return -1; - } - - struct sockaddr_in addr; - addr.sin_family = AF_INET; - memcpy(&addr.sin_addr.s_addr, manager->h_addr_list[0], manager->h_length); - addr.sin_port = htons(port); - - if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) != 0) { - close(fd); - return -1; - } - return fd; -} - -int accept_client(int socket_fd) { - int client_fd = accept(socket_fd, NULL, NULL); - if (client_fd < 0) { - RAY_LOG(ERROR) << "Error reading from socket."; - return -1; - } - return client_fd; -} - -int write_bytes(int fd, uint8_t *cursor, size_t length) { - ssize_t nbytes = 0; - size_t bytesleft = length; - size_t offset = 0; - while (bytesleft > 0) { - /* While we haven't written the whole message, write to the file - * descriptor, advance the cursor, and decrease the amount left to write. */ - nbytes = write(fd, cursor + offset, bytesleft); - if (nbytes < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { - continue; - } - return -1; /* Errno will be set. */ - } else if (0 == nbytes) { - /* Encountered early EOF. */ - return -1; - } - RAY_CHECK(nbytes > 0); - bytesleft -= nbytes; - offset += nbytes; - } - - return 0; -} - -int do_write_message(int fd, int64_t type, int64_t length, uint8_t *bytes) { - int64_t version = RayConfig::instance().ray_protocol_version(); - int closed; - closed = write_bytes(fd, (uint8_t *) &version, sizeof(version)); - if (closed) { - return closed; - } - closed = write_bytes(fd, (uint8_t *) &type, sizeof(type)); - if (closed) { - return closed; - } - closed = write_bytes(fd, (uint8_t *) &length, sizeof(length)); - if (closed) { - return closed; - } - closed = write_bytes(fd, bytes, length * sizeof(char)); - if (closed) { - return closed; - } - return 0; -} - -int write_message(int fd, - int64_t type, - int64_t length, - uint8_t *bytes, - std::mutex *mutex) { - if (mutex != NULL) { - std::unique_lock guard(*mutex); - return do_write_message(fd, type, length, bytes); - } else { - return do_write_message(fd, type, length, bytes); - } -} - -int read_bytes(int fd, uint8_t *cursor, size_t length) { - ssize_t nbytes = 0; - /* Termination condition: EOF or read 'length' bytes total. */ - size_t bytesleft = length; - size_t offset = 0; - while (bytesleft > 0) { - nbytes = read(fd, cursor + offset, bytesleft); - if (nbytes < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { - continue; - } - return -1; /* Errno will be set. */ - } else if (0 == nbytes) { - /* Encountered early EOF. 
*/ - return -1; - } - RAY_CHECK(nbytes > 0); - bytesleft -= nbytes; - offset += nbytes; - } - - return 0; -} - -void read_message(int fd, int64_t *type, int64_t *length, uint8_t **bytes) { - int64_t version; - int closed = read_bytes(fd, (uint8_t *) &version, sizeof(version)); - if (closed) { - goto disconnected; - } - RAY_CHECK(version == RayConfig::instance().ray_protocol_version()); - closed = read_bytes(fd, (uint8_t *) type, sizeof(*type)); - if (closed) { - goto disconnected; - } - closed = read_bytes(fd, (uint8_t *) length, sizeof(*length)); - if (closed) { - goto disconnected; - } - *bytes = (uint8_t *) malloc(*length * sizeof(uint8_t)); - closed = read_bytes(fd, *bytes, *length); - if (closed) { - free(*bytes); - goto disconnected; - } - return; - -disconnected: - /* Handle the case in which the socket is closed. */ - *type = static_cast(CommonMessageType::DISCONNECT_CLIENT); - *length = 0; - *bytes = NULL; - return; -} - -uint8_t *read_message_async(event_loop *loop, int sock) { - int64_t size; - int error = read_bytes(sock, (uint8_t *) &size, sizeof(int64_t)); - if (error < 0) { - /* The other side has closed the socket. */ - RAY_LOG(DEBUG) << "Socket has been closed, or some other error has " - << "occurred."; - if (loop != NULL) { - event_loop_remove_file(loop, sock); - } - close(sock); - return NULL; - } - uint8_t *message = (uint8_t *) malloc(size); - error = read_bytes(sock, message, size); - if (error < 0) { - /* The other side has closed the socket. */ - RAY_LOG(DEBUG) << "Socket has been closed, or some other error has " - << "occurred."; - if (loop != NULL) { - event_loop_remove_file(loop, sock); - } - close(sock); - return NULL; - } - return message; -} - -int64_t read_vector(int fd, int64_t *type, std::vector &buffer) { - int64_t version; - int closed = read_bytes(fd, (uint8_t *) &version, sizeof(version)); - if (closed) { - goto disconnected; - } - RAY_CHECK(version == RayConfig::instance().ray_protocol_version()); - int64_t length; - closed = read_bytes(fd, (uint8_t *) type, sizeof(*type)); - if (closed) { - goto disconnected; - } - closed = read_bytes(fd, (uint8_t *) &length, sizeof(length)); - if (closed) { - goto disconnected; - } - if (static_cast(length) > buffer.size()) { - buffer.resize(length); - } - closed = read_bytes(fd, buffer.data(), length); - if (closed) { - goto disconnected; - } - return length; -disconnected: - /* Handle the case in which the socket is closed. */ - *type = static_cast(CommonMessageType::DISCONNECT_CLIENT); - return 0; -} - -void write_log_message(int fd, const char *message) { - /* Account for the \0 at the end of the string. */ - do_write_message(fd, static_cast(CommonMessageType::LOG_MESSAGE), - strlen(message) + 1, (uint8_t *) message); -} - -char *read_log_message(int fd) { - uint8_t *bytes; - int64_t type; - int64_t length; - read_message(fd, &type, &length, &bytes); - RAY_CHECK(static_cast(type) == - CommonMessageType::LOG_MESSAGE); - return (char *) bytes; -} diff --git a/src/common/io.h b/src/common/io.h deleted file mode 100644 index 3f976445aeb05..0000000000000 --- a/src/common/io.h +++ /dev/null @@ -1,228 +0,0 @@ -#ifndef IO_H -#define IO_H - -#include -#include - -#include -#include - -struct aeEventLoop; -typedef aeEventLoop event_loop; - -enum class CommonMessageType : int32_t { - /** Disconnect a client. */ - DISCONNECT_CLIENT, - /** Log a message from a client. */ - LOG_MESSAGE, - /** Submit a task to the local scheduler. */ - SUBMIT_TASK, -}; - -/* Helper functions for socket communication. 
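The write_message/read_message pair declared below frames every message as three raw int64_t fields written back to back in host byte order (protocol version, message type, payload length), followed by the payload bytes; read_message allocates the payload buffer for the caller.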
*/ - -/** - * Binds to an Internet socket at the given port. Removes any existing file at - * the pathname. Returns a non-blocking file descriptor for the socket, or -1 - * if an error occurred. - * - * @note Since the returned file descriptor is non-blocking, it is not - * recommended to use the Linux read and write calls directly, since these - * might read or write a partial message. Instead, use the provided - * write_message and read_message methods. - * - * @param port The port to bind to. - * @param shall_listen Are we also starting to listen on the socket? - * @return A non-blocking file descriptor for the socket, or -1 if an error - * occurs. - */ -int bind_inet_sock(const int port, bool shall_listen); - -/** - * Binds to a Unix domain streaming socket at the given - * pathname. Removes any existing file at the pathname. - * - * @param socket_pathname The pathname for the socket. - * @param shall_listen Are we also starting to listen on the socket? - * @return A blocking file descriptor for the socket, or -1 if an error - * occurs. - */ -int bind_ipc_sock(const char *socket_pathname, bool shall_listen); - -/** - * Connect to a Unix domain streaming socket at the given - * pathname. - * - * @param socket_pathname The pathname for the socket. - * @return A file descriptor for the socket, or -1 if an error occurred. - */ -int connect_ipc_sock(const char *socket_pathname); - -/** - * Connect to a Unix domain streaming socket at the given - * pathname, or fail after some number of retries. - * - * @param socket_pathname The pathname for the socket. - * @param num_retries The number of times to retry the connection - * before exiting. If -1 is provided, then this defaults to - * num_connect_attempts. - * @param timeout The number of milliseconds to wait in between - * retries. If -1 is provided, then this defaults to - * connect_timeout_milliseconds. - * @return A file descriptor for the socket, or -1 if an error occurred. - */ -int connect_ipc_sock_retry(const char *socket_pathname, - int num_retries, - int64_t timeout); - -/** - * Connect to an Internet socket at the given address and port. - * - * @param ip_addr The IP address to connect to. - * @param port The port number to connect to. - * - * @param socket_pathname The pathname for the socket. - * @return A file descriptor for the socket, or -1 if an error occurred. - */ -int connect_inet_sock(const char *ip_addr, int port); - -/** - * Connect to an Internet socket at the given address and port, or fail after - * some number of retries. - * - * @param ip_addr The IP address to connect to. - * @param port The port number to connect to. - * @param num_retries The number of times to retry the connection - * before exiting. If -1 is provided, then this defaults to - * num_connect_attempts. - * @param timeout The number of milliseconds to wait in between - * retries. If -1 is provided, then this defaults to - * connect_timeout_milliseconds. - * @return A file descriptor for the socket, or -1 if an error occurred. - */ -int connect_inet_sock_retry(const char *ip_addr, - int port, - int num_retries, - int64_t timeout); - -/** - * Accept a new client connection on the given socket - * descriptor. Returns a descriptor for the new socket. - */ -int accept_client(int socket_fd); - -/* Reading and writing data. */ - -/** - * Write a sequence of bytes on a file descriptor. The bytes should then be read - * by read_message. - * - * @param fd The file descriptor to write to. It can be non-blocking. - * @param version The protocol version. 
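 * (Note: the caller does not pass the version explicitly; write_message fills in RayConfig::instance().ray_protocol_version() itself, and read_message verifies it on receipt.)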
- * @param type The type of the message to send. - * @param length The size in bytes of the bytes parameter. - * @param bytes The address of the message to send. - * @param mutex If not NULL, the whole write operation will be locked - * with this mutex, otherwise do nothing. - * @return int Whether there was an error while writing. 0 corresponds to - * success and -1 corresponds to an error (errno will be set). - */ -int write_message(int fd, - int64_t type, - int64_t length, - uint8_t *bytes, - std::mutex *mutex = NULL); - -/** - * Read a sequence of bytes written by write_message from a file descriptor. - * This allocates space for the message. - * - * @note The caller must free the memory. - * - * @param fd The file descriptor to read from. It can be non-blocking. - * @param type The type of the message that is read will be written at this - * address. If there was an error while reading, this will be - * DISCONNECT_CLIENT. - * @param length The size in bytes of the message that is read will be written - * at this address. This size does not include the bytes used to encode - * the type and length. If there was an error while reading, this will - * be 0. - * @param bytes The address at which to write the pointer to the bytes that are - * read and allocated by this function. If there was an error while - * reading, this will be NULL. - * @return Void. - */ -void read_message(int fd, int64_t *type, int64_t *length, uint8_t **bytes); - -/** - * Read a message from a file descriptor and remove the file descriptor from the - * event loop if there is an error. This will actually do two reads. The first - * read reads sizeof(int64_t) bytes to determine the number of bytes to read in - * the next read. - * - * @param loop: The event loop. - * @param sock: The file descriptor to read from. - * @return A byte buffer contining the message or NULL if there was an - * error. The buffer needs to be freed by the user. - */ -uint8_t *read_message_async(event_loop *loop, int sock); - -/** - * Read a sequence of bytes written by write_message from a file descriptor. - * This does not allocate space for the message if the provided buffer is - * large enough and can therefore often avoid allocations. - * - * @param fd The file descriptor to read from. It can be non-blocking. - * @param type The type of the message that is read will be written at this - * address. If there was an error while reading, this will be - * DISCONNECT_CLIENT. - * @param buffer The array the message will be written to. If it is not - * large enough to hold the message, it will be enlarged by read_vector. - * @return Number of bytes of the message that were read. This size does not - * include the bytes used to encode the type and length. If there was - * an error while reading, this will be 0. - */ -int64_t read_vector(int fd, int64_t *type, std::vector &buffer); - -/** - * Write a null-terminated string to a file descriptor. - */ -void write_log_message(int fd, const char *message); - -/** - * Reads a null-terminated string from the file descriptor that has been - * written by write_log_message. Allocates and returns a pointer to the string. - * NOTE: Caller must free the memory! - */ -char *read_log_message(int fd); - -/** - * Read a sequence of bytes from a file descriptor into a buffer. This will - * block until one of the following happens: (1) there is an error (2) end of - * file, or (3) all length bytes have been written. 
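A minimal usage sketch of the framing helpers declared above, not taken from the original tree: it assumes the declarations in this header are linked against the io.cc implementation deleted above, default RayConfig settings, and an include path under which "io.h" resolves.

#include <sys/socket.h>

#include <cstdlib>

#include "io.h"

int main() {
  int fds[2];
  if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) {
    return 1;
  }
  const char payload[] = "hello";
  // write_message prepends the protocol version, the message type, and the
  // payload length before the payload bytes.
  write_message(fds[0], static_cast<int64_t>(CommonMessageType::LOG_MESSAGE),
                sizeof(payload), (uint8_t *) payload);
  // read_message allocates the payload buffer; the caller must free it.
  int64_t type = 0;
  int64_t length = 0;
  uint8_t *bytes = NULL;
  read_message(fds[1], &type, &length, &bytes);
  free(bytes);
  return 0;
}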
- * - * @note The buffer pointed to by cursor must already have length number of - * bytes allocated before calling this method. - * - * @param fd The file descriptor to read from. It can be non-blocking. - * @param cursor The cursor pointing to the beginning of the buffer. - * @param length The size of the byte sequence to read. - * @return int Whether there was an error while reading. 0 corresponds to - * success and -1 corresponds to an error (errno will be set). - */ -int read_bytes(int fd, uint8_t *cursor, size_t length); - -/** - * Write a sequence of bytes into a file descriptor. This will block until one - * of the following happens: (1) there is an error (2) end of file, or (3) all - * length bytes have been written. - * - * @param fd The file descriptor to write to. It can be non-blocking. - * @param cursor The cursor pointing to the beginning of the bytes to send. - * @param length The size of the bytes sequence to write. - * @return int Whether there was an error while writing. 0 corresponds to - * success and -1 corresponds to an error (errno will be set). - */ -int write_bytes(int fd, uint8_t *cursor, size_t length); - -#endif /* IO_H */ diff --git a/src/common/lib/python/common_extension.cc b/src/common/lib/python/common_extension.cc deleted file mode 100644 index 68965e270980c..0000000000000 --- a/src/common/lib/python/common_extension.cc +++ /dev/null @@ -1,919 +0,0 @@ -#include -#include "bytesobject.h" -#include "node.h" - -// Don't use the deprecated Numpy functions. -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION - -#include - -#include "common.h" -#include "common_extension.h" -#include "common_protocol.h" -#include "ray/raylet/task.h" -#include "ray/raylet/task_spec.h" -#include "ray/raylet/task_execution_spec.h" -#include "task.h" - -#include - -#if PY_MAJOR_VERSION >= 3 -#define PyInt_Check PyLong_Check -#endif - -PyObject *CommonError; - -/* Initialize pickle module. */ - -PyObject *pickle_module = NULL; -PyObject *pickle_loads = NULL; -PyObject *pickle_dumps = NULL; -PyObject *pickle_protocol = NULL; - -int init_numpy_module(void) { - import_array1(-1); - return 0; -} - -void init_pickle_module(void) { -#if PY_MAJOR_VERSION >= 3 - pickle_module = PyImport_ImportModule("pickle"); -#else - pickle_module = PyImport_ImportModuleNoBlock("cPickle"); -#endif - RAY_CHECK(pickle_module != NULL); - RAY_CHECK(PyObject_HasAttrString(pickle_module, "loads")); - RAY_CHECK(PyObject_HasAttrString(pickle_module, "dumps")); - RAY_CHECK(PyObject_HasAttrString(pickle_module, "HIGHEST_PROTOCOL")); - pickle_loads = PyUnicode_FromString("loads"); - pickle_dumps = PyUnicode_FromString("dumps"); - pickle_protocol = PyObject_GetAttrString(pickle_module, "HIGHEST_PROTOCOL"); - RAY_CHECK(pickle_protocol != NULL); -} - -TaskBuilder *g_task_builder = NULL; - -/* Define the PyObjectID class. 
*/ - -int PyStringToUniqueID(PyObject *object, ObjectID *object_id) { - if (PyBytes_Check(object)) { - std::memcpy(object_id->mutable_data(), PyBytes_AsString(object), - sizeof(*object_id)); - return 1; - } else { - PyErr_SetString(PyExc_TypeError, "must be a 20 character string"); - return 0; - } -} - -int PyObjectToUniqueID(PyObject *object, ObjectID *objectid) { - if (PyObject_IsInstance(object, (PyObject *) &PyObjectIDType)) { - *objectid = ((PyObjectID *) object)->object_id; - return 1; - } else { - PyErr_SetString(PyExc_TypeError, "must be an ObjectID"); - return 0; - } -} - -bool use_raylet(PyTask *task) { - return task->spec == nullptr; -} - -static int PyObjectID_init(PyObjectID *self, PyObject *args, PyObject *kwds) { - const char *data; - int size; - if (!PyArg_ParseTuple(args, "s#", &data, &size)) { - return -1; - } - if (size != sizeof(ObjectID)) { - PyErr_SetString(CommonError, - "ObjectID: object id string needs to have length 20"); - return -1; - } - std::memcpy(self->object_id.mutable_data(), data, sizeof(self->object_id)); - return 0; -} - -/* Create a PyObjectID from C. */ -PyObject *PyObjectID_make(ObjectID object_id) { - PyObjectID *result = PyObject_New(PyObjectID, &PyObjectIDType); - result = (PyObjectID *) PyObject_Init((PyObject *) result, &PyObjectIDType); - result->object_id = object_id; - return (PyObject *) result; -} - -/** - * Convert a string to a Ray task specification Python object. - * - * This is called from Python like - * - * task = local_scheduler.task_from_string("...") - * - * @param task_string String representation of the task specification. - * @return Python task specification object. - */ -PyObject *PyTask_from_string(PyObject *self, PyObject *args) { - const char *data; - int size; - if (!PyArg_ParseTuple(args, "s#", &data, &size)) { - return NULL; - } - PyTask *result = PyObject_New(PyTask, &PyTaskType); - result = (PyTask *) PyObject_Init((PyObject *) result, &PyTaskType); - result->size = size; - result->spec = TaskSpec_copy((TaskSpec *) data, size); - /* The created task does not include any execution dependencies. */ - result->execution_dependencies = new std::vector(); - /* TODO(pcm): Use flatbuffers validation here. */ - return (PyObject *) result; -} - -/** - * Convert a Ray task specification Python object to a string. - * - * This is called from Python like - * - * s = local_scheduler.task_to_string(task) - * - * @param task Ray task specification Python object. - * @return String representing the task specification. 
- */ -PyObject *PyTask_to_string(PyObject *self, PyObject *args) { - PyObject *arg; - if (!PyArg_ParseTuple(args, "O", &arg)) { - return NULL; - } - PyTask *task = (PyTask *) arg; - if (!use_raylet(task)) { - return PyBytes_FromStringAndSize((char *) task->spec, task->size); - } else { - flatbuffers::FlatBufferBuilder fbb; - auto task_spec_string = task->task_spec->ToFlatbuffer(fbb); - fbb.Finish(task_spec_string); - return PyBytes_FromStringAndSize((char *) fbb.GetBufferPointer(), - fbb.GetSize()); - } -} - -static PyObject *PyObjectID_id(PyObject *self) { - PyObjectID *s = (PyObjectID *) self; - return PyBytes_FromStringAndSize((const char *) s->object_id.data(), - sizeof(s->object_id)); -} - -static PyObject *PyObjectID_hex(PyObject *self) { - PyObjectID *s = (PyObjectID *) self; - std::string hex_id = s->object_id.hex(); -#if PY_MAJOR_VERSION >= 3 - PyObject *result = PyUnicode_FromStringAndSize(hex_id.data(), hex_id.size()); -#else - PyObject *result = PyBytes_FromStringAndSize(hex_id.data(), hex_id.size()); -#endif - return result; -} - -static PyObject *PyObjectID_richcompare(PyObjectID *self, - PyObject *other, - int op) { - PyObject *result = NULL; - if (Py_TYPE(self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) { - result = Py_NotImplemented; - } else { - PyObjectID *other_id = (PyObjectID *) other; - switch (op) { - case Py_LT: - result = Py_NotImplemented; - break; - case Py_LE: - result = Py_NotImplemented; - break; - case Py_EQ: - result = self->object_id == other_id->object_id ? Py_True : Py_False; - break; - case Py_NE: - result = !(self->object_id == other_id->object_id) ? Py_True : Py_False; - break; - case Py_GT: - result = Py_NotImplemented; - break; - case Py_GE: - result = Py_NotImplemented; - break; - } - } - Py_XINCREF(result); - return result; -} - -static PyObject *PyObjectID_redis_shard_hash(PyObjectID *self) { - /* NOTE: The hash function used here must match the one in get_redis_context - * in src/common/state/redis.cc. Changes to the hash function should only be - * made through std::hash in src/common/common.h */ - std::hash hash; - return PyLong_FromSize_t(hash(self->object_id)); -} - -static long PyObjectID_hash(PyObjectID *self) { - // TODO(pcm): Replace this with a faster hash function. This currently - // creates a tuple of length 20 and hashes it, which is slow - PyObject *tuple = PyTuple_New(kUniqueIDSize); - for (int i = 0; i < kUniqueIDSize; ++i) { - PyTuple_SetItem(tuple, i, PyLong_FromLong(self->object_id.data()[i])); - } - long hash = PyObject_Hash(tuple); - Py_XDECREF(tuple); - return hash; -} - -static PyObject *PyObjectID_repr(PyObjectID *self) { - std::string repr = "ObjectID(" + self->object_id.hex() + ")"; - PyObject *result = PyUnicode_FromString(repr.c_str()); - return result; -} - -static PyObject *PyObjectID___reduce__(PyObjectID *self) { - PyErr_SetString(CommonError, "ObjectID objects cannot be serialized."); - return NULL; -} - -static PyMethodDef PyObjectID_methods[] = { - {"id", (PyCFunction) PyObjectID_id, METH_NOARGS, - "Return the hash associated with this ObjectID"}, - {"redis_shard_hash", (PyCFunction) PyObjectID_redis_shard_hash, METH_NOARGS, - "Return the redis shard that this ObjectID is associated with"}, - {"hex", (PyCFunction) PyObjectID_hex, METH_NOARGS, - "Return the object ID as a string in hex."}, - {"__reduce__", (PyCFunction) PyObjectID___reduce__, METH_NOARGS, - "Say how to pickle this ObjectID. 
This raises an exception to prevent" - "object IDs from being serialized."}, - {NULL} /* Sentinel */ -}; - -static PyMemberDef PyObjectID_members[] = { - {NULL} /* Sentinel */ -}; - -PyTypeObject PyObjectIDType = { - PyVarObject_HEAD_INIT(NULL, 0) /* ob_size */ - "common.ObjectID", /* tp_name */ - sizeof(PyObjectID), /* tp_basicsize */ - 0, /* tp_itemsize */ - 0, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - (reprfunc) PyObjectID_repr, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - (hashfunc) PyObjectID_hash, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - "ObjectID object", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - (richcmpfunc) PyObjectID_richcompare, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - PyObjectID_methods, /* tp_methods */ - PyObjectID_members, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc) PyObjectID_init, /* tp_init */ - 0, /* tp_alloc */ - PyType_GenericNew, /* tp_new */ -}; - -/* Define the PyTask class. */ - -static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { - /* ID of the driver that this task originates from. */ - UniqueID driver_id; - /* ID of the actor this task should run on. */ - UniqueID actor_id = ActorID::nil(); - /* ID of the actor handle used to submit this task. */ - UniqueID actor_handle_id = ActorHandleID::nil(); - /* How many tasks have been launched on the actor so far? */ - int actor_counter = 0; - /* True if this is an actor checkpoint task and false otherwise. */ - PyObject *is_actor_checkpoint_method_object = nullptr; - /* ID of the function this task executes. */ - FunctionID function_id; - /* Arguments of the task (can be PyObjectIDs or Python values). */ - PyObject *arguments; - /* Number of return values of this task. */ - int num_returns; - /* The ID of the task that called this task. */ - TaskID parent_task_id; - /* The number of tasks that the parent task has called prior to this one. */ - int parent_counter; - // The actor creation ID. - ActorID actor_creation_id = ActorID::nil(); - // The dummy object for the actor creation task (if this is an actor method). - ObjectID actor_creation_dummy_object_id = ObjectID::nil(); - /* Arguments of the task that are execution-dependent. These must be - * PyObjectIDs). */ - PyObject *execution_arguments = nullptr; - /* Dictionary of resource requirements for this task. */ - PyObject *resource_map = nullptr; - // True if we should use the raylet code path and false otherwise. 
- PyObject *use_raylet_object = nullptr; - if (!PyArg_ParseTuple( - args, "O&O&OiO&i|O&O&O&O&iOOOO", &PyObjectToUniqueID, &driver_id, - &PyObjectToUniqueID, &function_id, &arguments, &num_returns, - &PyObjectToUniqueID, &parent_task_id, &parent_counter, - &PyObjectToUniqueID, &actor_creation_id, &PyObjectToUniqueID, - &actor_creation_dummy_object_id, &PyObjectToUniqueID, &actor_id, - &PyObjectToUniqueID, &actor_handle_id, &actor_counter, - &is_actor_checkpoint_method_object, &execution_arguments, - &resource_map, &use_raylet_object)) { - return -1; - } - - bool is_actor_checkpoint_method = false; - if (is_actor_checkpoint_method_object != nullptr && - PyObject_IsTrue(is_actor_checkpoint_method_object) == 1) { - is_actor_checkpoint_method = true; - } - - // Parse the resource map. - std::unordered_map required_resources; - - bool found_CPU_requirements = false; - PyObject *key, *value; - Py_ssize_t position = 0; - if (resource_map != nullptr) { - if (!PyDict_Check(resource_map)) { - PyErr_SetString(PyExc_TypeError, "resource_map must be a dictionary"); - return -1; - } - while (PyDict_Next(resource_map, &position, &key, &value)) { - if (!(PyBytes_Check(key) || PyUnicode_Check(key))) { - PyErr_SetString(PyExc_TypeError, - "the keys in resource_map must be strings"); - return -1; - } - if (!(PyFloat_Check(value) || PyInt_Check(value) || - PyLong_Check(value))) { - PyErr_SetString(PyExc_TypeError, - "the values in resource_map must be floats"); - return -1; - } - // Handle the case where the key is a bytes object and the case where it - // is a unicode object. - std::string resource_name; - if (PyUnicode_Check(key)) { - PyObject *ascii_key = PyUnicode_AsASCIIString(key); - resource_name = - std::string(PyBytes_AsString(ascii_key), PyBytes_Size(ascii_key)); - Py_DECREF(ascii_key); - } else { - resource_name = std::string(PyBytes_AsString(key), PyBytes_Size(key)); - } - if (resource_name == std::string("CPU")) { - found_CPU_requirements = true; - } - required_resources[resource_name] = PyFloat_AsDouble(value); - } - } - if (!found_CPU_requirements) { - required_resources["CPU"] = 1.0; - } - - Py_ssize_t num_args = PyList_Size(arguments); - - bool use_raylet = false; - if (use_raylet_object != nullptr && PyObject_IsTrue(use_raylet_object) == 1) { - use_raylet = true; - } - self->spec = nullptr; - self->task_spec = nullptr; - - // Create the task spec. - if (!use_raylet) { - // The non-raylet code path. - - // Construct the task specification. - TaskSpec_start_construct( - g_task_builder, driver_id, parent_task_id, parent_counter, - actor_creation_id, actor_creation_dummy_object_id, actor_id, - actor_handle_id, actor_counter, is_actor_checkpoint_method, function_id, - num_returns); - // Add the task arguments. - for (Py_ssize_t i = 0; i < num_args; ++i) { - PyObject *arg = PyList_GetItem(arguments, i); - if (PyObject_IsInstance(arg, - reinterpret_cast(&PyObjectIDType))) { - TaskSpec_args_add_ref(g_task_builder, - &(reinterpret_cast(arg))->object_id, - 1); - } else { - PyObject *data = PyObject_CallMethodObjArgs(pickle_module, pickle_dumps, - arg, pickle_protocol, NULL); - TaskSpec_args_add_val( - g_task_builder, reinterpret_cast(PyBytes_AsString(data)), - PyBytes_Size(data)); - Py_DECREF(data); - } - } - // Set the resource requirements for the task. - for (auto const &resource_pair : required_resources) { - TaskSpec_set_required_resource(g_task_builder, resource_pair.first, - resource_pair.second); - } - - // Compute the task ID and the return object IDs. 
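// (TaskSpec_finish_construct serializes the builder state into a heap-allocated TaskSpec and records its size in self->size; the PyTask owns that buffer and releases it via TaskSpec_free in PyTask_dealloc.)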
- self->spec = TaskSpec_finish_construct(g_task_builder, &self->size); - - } else { - // The raylet code path. - - // Parse the arguments from the list. - std::vector> args; - for (Py_ssize_t i = 0; i < num_args; ++i) { - PyObject *arg = PyList_GetItem(arguments, i); - if (PyObject_IsInstance(arg, - reinterpret_cast(&PyObjectIDType))) { - std::vector references = { - reinterpret_cast(arg)->object_id}; - args.push_back( - std::make_shared(references)); - } else { - PyObject *data = PyObject_CallMethodObjArgs(pickle_module, pickle_dumps, - arg, pickle_protocol, NULL); - args.push_back(std::make_shared( - reinterpret_cast(PyBytes_AsString(data)), - PyBytes_Size(data))); - Py_DECREF(data); - } - } - - self->task_spec = new ray::raylet::TaskSpecification( - driver_id, parent_task_id, parent_counter, actor_creation_id, - actor_creation_dummy_object_id, actor_id, actor_handle_id, - actor_counter, function_id, args, num_returns, required_resources, - Language::PYTHON); - } - - /* Set the task's execution dependencies. */ - self->execution_dependencies = new std::vector(); - if (execution_arguments != NULL) { - Py_ssize_t num_execution_args = PyList_Size(execution_arguments); - for (Py_ssize_t i = 0; i < num_execution_args; ++i) { - PyObject *execution_arg = PyList_GetItem(execution_arguments, i); - if (!PyObject_IsInstance(execution_arg, (PyObject *) &PyObjectIDType)) { - PyErr_SetString(PyExc_TypeError, - "Execution arguments must be an ObjectID."); - return -1; - } - self->execution_dependencies->push_back( - ((PyObjectID *) execution_arg)->object_id); - } - } - - return 0; -} - -static void PyTask_dealloc(PyTask *self) { - if (!use_raylet(self)) { - TaskSpec_free(self->spec); - } else { - delete self->task_spec; - } - delete self->execution_dependencies; - Py_TYPE(self)->tp_free(reinterpret_cast(self)); -} - -static PyObject *PyTask_function_id(PyTask *self) { - FunctionID function_id; - if (!use_raylet(self)) { - function_id = TaskSpec_function(self->spec); - } else { - function_id = self->task_spec->FunctionId(); - } - return PyObjectID_make(function_id); -} - -static PyObject *PyTask_actor_id(PyTask *self) { - ActorID actor_id; - if (!use_raylet(self)) { - actor_id = TaskSpec_actor_id(self->spec); - } else { - actor_id = self->task_spec->ActorId(); - } - return PyObjectID_make(actor_id); -} - -static PyObject *PyTask_actor_counter(PyTask *self) { - int64_t actor_counter; - if (!use_raylet(self)) { - actor_counter = TaskSpec_actor_counter(self->spec); - } else { - actor_counter = self->task_spec->ActorCounter(); - } - return PyLong_FromLongLong(actor_counter); -} - -static PyObject *PyTask_driver_id(PyTask *self) { - UniqueID driver_id; - if (!use_raylet(self)) { - driver_id = TaskSpec_driver_id(self->spec); - } else { - driver_id = self->task_spec->DriverId(); - } - return PyObjectID_make(driver_id); -} - -static PyObject *PyTask_task_id(PyTask *self) { - TaskID task_id; - if (!use_raylet(self)) { - task_id = TaskSpec_task_id(self->spec); - } else { - task_id = self->task_spec->TaskId(); - } - return PyObjectID_make(task_id); -} - -static PyObject *PyTask_parent_task_id(PyTask *self) { - TaskID task_id; - if (!use_raylet(self)) { - task_id = TaskSpec_parent_task_id(self->spec); - } else { - task_id = self->task_spec->ParentTaskId(); - } - return PyObjectID_make(task_id); -} - -static PyObject *PyTask_parent_counter(PyTask *self) { - int64_t parent_counter; - if (!use_raylet(self)) { - parent_counter = TaskSpec_parent_counter(self->spec); - } else { - parent_counter = 
self->task_spec->ParentCounter(); - } - return PyLong_FromLongLong(parent_counter); -} - -static PyObject *PyTask_arguments(PyTask *self) { - TaskSpec *task = self->spec; - ray::raylet::TaskSpecification *task_spec = self->task_spec; - - int64_t num_args; - if (!use_raylet(self)) { - num_args = TaskSpec_num_args(task); - } else { - num_args = self->task_spec->NumArgs(); - } - - PyObject *arg_list = PyList_New((Py_ssize_t) num_args); - for (int i = 0; i < num_args; ++i) { - int count; - if (!use_raylet(self)) { - count = TaskSpec_arg_id_count(task, i); - } else { - count = task_spec->ArgIdCount(i); - } - - if (count > 0) { - assert(count == 1); - - ObjectID object_id; - if (!use_raylet(self)) { - object_id = TaskSpec_arg_id(task, i, 0); - } else { - object_id = task_spec->ArgId(i, 0); - } - - PyList_SetItem(arg_list, i, PyObjectID_make(object_id)); - } else { - RAY_CHECK(pickle_module != NULL); - RAY_CHECK(pickle_loads != NULL); - - const uint8_t *arg_val; - int64_t arg_length; - if (!use_raylet(self)) { - arg_val = TaskSpec_arg_val(task, i); - arg_length = TaskSpec_arg_length(task, i); - } else { - arg_val = task_spec->ArgVal(i); - arg_length = task_spec->ArgValLength(i); - } - - PyObject *str = - PyBytes_FromStringAndSize(reinterpret_cast(arg_val), - static_cast(arg_length)); - PyObject *val = - PyObject_CallMethodObjArgs(pickle_module, pickle_loads, str, NULL); - Py_XDECREF(str); - PyList_SetItem(arg_list, i, val); - } - } - return arg_list; -} - -static PyObject *PyTask_actor_creation_id(PyTask *self) { - ActorID actor_creation_id; - if (!use_raylet(self)) { - actor_creation_id = TaskSpec_actor_creation_id(self->spec); - } else { - actor_creation_id = self->task_spec->ActorCreationId(); - } - return PyObjectID_make(actor_creation_id); -} - -static PyObject *PyTask_actor_creation_dummy_object_id(PyTask *self) { - ObjectID actor_creation_dummy_object_id; - if (!use_raylet(self)) { - if (TaskSpec_is_actor_task(self->spec)) { - actor_creation_dummy_object_id = - TaskSpec_actor_creation_dummy_object_id(self->spec); - } else { - actor_creation_dummy_object_id = ObjectID::nil(); - } - } else { - actor_creation_dummy_object_id = - self->task_spec->ActorCreationDummyObjectId(); - } - return PyObjectID_make(actor_creation_dummy_object_id); -} - -static PyObject *PyTask_required_resources(PyTask *self) { - PyObject *required_resources = PyDict_New(); - - std::unordered_map resource_map; - if (!use_raylet(self)) { - resource_map = TaskSpec_get_required_resources(self->spec); - } else { - resource_map = self->task_spec->GetRequiredResources().GetResourceMap(); - } - - for (auto const &resource_pair : resource_map) { - std::string resource_name = resource_pair.first; -#if PY_MAJOR_VERSION >= 3 - PyObject *key = - PyUnicode_FromStringAndSize(resource_name.data(), resource_name.size()); -#else - PyObject *key = - PyBytes_FromStringAndSize(resource_name.data(), resource_name.size()); -#endif - PyObject *value = PyFloat_FromDouble(resource_pair.second); - PyDict_SetItem(required_resources, key, value); - Py_DECREF(key); - Py_DECREF(value); - } - return required_resources; -} - -static PyObject *PyTask_returns(PyTask *self) { - TaskSpec *task = self->spec; - ray::raylet::TaskSpecification *task_spec = self->task_spec; - - int64_t num_returns; - if (!use_raylet(self)) { - num_returns = TaskSpec_num_returns(task); - } else { - num_returns = task_spec->NumReturns(); - } - - PyObject *return_id_list = PyList_New((Py_ssize_t) num_returns); - for (int i = 0; i < num_returns; ++i) { - ObjectID object_id; - if 
(!use_raylet(self)) { - object_id = TaskSpec_return(task, i); - } else { - object_id = task_spec->ReturnId(i); - } - PyList_SetItem(return_id_list, i, PyObjectID_make(object_id)); - } - return return_id_list; -} - -static PyObject *PyTask_execution_dependencies_string(PyTask *self) { - flatbuffers::FlatBufferBuilder fbb; - auto execution_dependencies = CreateTaskExecutionDependencies( - fbb, to_flatbuf(fbb, *self->execution_dependencies)); - fbb.Finish(execution_dependencies); - return PyBytes_FromStringAndSize((char *) fbb.GetBufferPointer(), - fbb.GetSize()); -} - -static PyObject *PyTask_to_serialized_flatbuf(PyTask *self) { - RAY_CHECK(use_raylet(self)); - - const std::vector execution_dependencies( - *self->execution_dependencies); - auto const execution_spec = ray::raylet::TaskExecutionSpecification( - std::move(execution_dependencies)); - auto const task = ray::raylet::Task(execution_spec, *self->task_spec); - - flatbuffers::FlatBufferBuilder fbb; - auto task_flatbuffer = task.ToFlatbuffer(fbb); - fbb.Finish(task_flatbuffer); - - return PyBytes_FromStringAndSize( - reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); -} - -static PyMethodDef PyTask_methods[] = { - {"function_id", (PyCFunction) PyTask_function_id, METH_NOARGS, - "Return the function ID for this task."}, - {"parent_task_id", (PyCFunction) PyTask_parent_task_id, METH_NOARGS, - "Return the task ID of the parent task."}, - {"parent_counter", (PyCFunction) PyTask_parent_counter, METH_NOARGS, - "Return the parent counter of this task."}, - {"actor_id", (PyCFunction) PyTask_actor_id, METH_NOARGS, - "Return the actor ID for this task."}, - {"actor_counter", (PyCFunction) PyTask_actor_counter, METH_NOARGS, - "Return the actor counter for this task."}, - {"driver_id", (PyCFunction) PyTask_driver_id, METH_NOARGS, - "Return the driver ID for this task."}, - {"task_id", (PyCFunction) PyTask_task_id, METH_NOARGS, - "Return the task ID for this task."}, - {"arguments", (PyCFunction) PyTask_arguments, METH_NOARGS, - "Return the arguments for the task."}, - {"actor_creation_id", (PyCFunction) PyTask_actor_creation_id, METH_NOARGS, - "Return the actor creation ID for the task."}, - {"actor_creation_dummy_object_id", - (PyCFunction) PyTask_actor_creation_dummy_object_id, METH_NOARGS, - "Return the actor creation dummy object ID for the task."}, - {"required_resources", (PyCFunction) PyTask_required_resources, METH_NOARGS, - "Return the resource vector of the task."}, - {"returns", (PyCFunction) PyTask_returns, METH_NOARGS, - "Return the object IDs for the return values of the task."}, - {"execution_dependencies_string", - (PyCFunction) PyTask_execution_dependencies_string, METH_NOARGS, - "Return the execution dependencies for the task as a string."}, - {"_serialized_raylet_task", (PyCFunction) PyTask_to_serialized_flatbuf, - METH_NOARGS, - "This is a hack used to create a serialized flatbuffer object for the " - "driver task. 
We're doing this because creating the flatbuffer object in " - "Python didn't seem to work."}, - {NULL} /* Sentinel */ -}; - -PyTypeObject PyTaskType = { - PyVarObject_HEAD_INIT(NULL, 0) /* ob_size */ - "task.Task", /* tp_name */ - sizeof(PyTask), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor) PyTask_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - "Task object", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - PyTask_methods, /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc) PyTask_init, /* tp_init */ - 0, /* tp_alloc */ - PyType_GenericNew, /* tp_new */ -}; - -/* Create a PyTask from a C struct. The resulting PyTask takes ownership of the - * TaskSpec and will deallocate the TaskSpec in the PyTask destructor. */ -PyObject *PyTask_make(TaskSpec *task_spec, int64_t task_size) { - PyTask *result = PyObject_New(PyTask, &PyTaskType); - result = (PyTask *) PyObject_Init((PyObject *) result, &PyTaskType); - result->spec = task_spec; - result->size = task_size; - /* The created task does not include any execution dependencies. */ - result->execution_dependencies = new std::vector(); - return (PyObject *) result; -} - -/* Define the methods for the module. */ - -/** - * This method checks if a Python object is sufficiently simple that it can be - * serialized and passed by value as an argument to a task (without being put in - * the object store). The details of which objects are sufficiently simple are - * defined by this method and are not particularly important. But for - * performance reasons, it is better to place "small" objects in the task itself - * and "large" objects in the object store. - * - * @param value The Python object in question. - * @param num_elements_contained If this method returns 1, then the number of - * objects recursively contained within this object will be added to the - * value at this address. This is used to make sure that we do not - * serialize objects that are too large. - * @return False if the object cannot be serialized in the task and true if it - * can. 
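 * (In effect this is a conservative recursive size estimate, cut off by RayConfig's num_elements_limit and size_limit, as implemented below.)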
- */ -bool is_simple_value(PyObject *value, int *num_elements_contained) { - *num_elements_contained += 1; - if (*num_elements_contained >= RayConfig::instance().num_elements_limit()) { - return false; - } - if (PyInt_Check(value) || PyLong_Check(value) || value == Py_False || - value == Py_True || PyFloat_Check(value) || value == Py_None) { - return true; - } - if (PyBytes_CheckExact(value)) { - *num_elements_contained += PyBytes_Size(value); - return (*num_elements_contained < - RayConfig::instance().num_elements_limit()); - } - if (PyUnicode_CheckExact(value)) { - *num_elements_contained += PyUnicode_GET_SIZE(value); - return (*num_elements_contained < - RayConfig::instance().num_elements_limit()); - } - if (PyList_CheckExact(value) && - PyList_Size(value) < RayConfig::instance().size_limit()) { - for (Py_ssize_t i = 0; i < PyList_Size(value); ++i) { - if (!is_simple_value(PyList_GetItem(value, i), num_elements_contained)) { - return false; - } - } - return (*num_elements_contained < - RayConfig::instance().num_elements_limit()); - } - if (PyDict_CheckExact(value) && - PyDict_Size(value) < RayConfig::instance().size_limit()) { - PyObject *key, *val; - Py_ssize_t pos = 0; - while (PyDict_Next(value, &pos, &key, &val)) { - if (!is_simple_value(key, num_elements_contained) || - !is_simple_value(val, num_elements_contained)) { - return false; - } - } - return (*num_elements_contained < - RayConfig::instance().num_elements_limit()); - } - if (PyTuple_CheckExact(value) && - PyTuple_Size(value) < RayConfig::instance().size_limit()) { - for (Py_ssize_t i = 0; i < PyTuple_Size(value); ++i) { - if (!is_simple_value(PyTuple_GetItem(value, i), num_elements_contained)) { - return false; - } - } - return (*num_elements_contained < - RayConfig::instance().num_elements_limit()); - } - if (PyArray_CheckExact(value)) { - PyArrayObject *array = reinterpret_cast(value); - if (PyArray_TYPE(array) == NPY_OBJECT) { - return false; - } - *num_elements_contained += PyArray_NBYTES(array); - return (*num_elements_contained < - RayConfig::instance().num_elements_limit()); - } - return false; -} - -PyObject *check_simple_value(PyObject *self, PyObject *args) { - PyObject *value; - if (!PyArg_ParseTuple(args, "O", &value)) { - return NULL; - } - int num_elements_contained = 0; - if (is_simple_value(value, &num_elements_contained)) { - Py_RETURN_TRUE; - } - Py_RETURN_FALSE; -} - -PyObject *compute_task_id(PyObject *self, PyObject *args) { - ObjectID object_id; - if (!PyArg_ParseTuple(args, "O&", &PyObjectToUniqueID, &object_id)) { - return NULL; - } - TaskID task_id = ray::ComputeTaskId(object_id); - return PyObjectID_make(task_id); -} diff --git a/src/common/logging.cc b/src/common/logging.cc deleted file mode 100644 index 9802dd3d03f3a..0000000000000 --- a/src/common/logging.cc +++ /dev/null @@ -1,107 +0,0 @@ -#include "logging.h" - -#include -#include -#include - -#include - -#include "state/redis.h" -#include "io.h" -#include -#include - -static const char *log_levels[5] = {"DEBUG", "INFO", "WARN", "ERROR", "FATAL"}; -static const char *log_fmt = - "HMSET log:%s:%s log_level %s event_type %s message %s timestamp %s"; - -struct RayLoggerImpl { - /* String that identifies this client type. */ - const char *client_type; - /* Suppress all log messages below this level. */ - int log_level; - /* Whether or not we have a direct connection to Redis. */ - int is_direct; - /* Either a db_handle or a socket to a process with a db_handle, - * depending on the is_direct flag. 
*/ - void *conn; -}; - -RayLogger *RayLogger_init(const char *client_type, - int log_level, - int is_direct, - void *conn) { - RayLogger *logger = (RayLogger *) malloc(sizeof(RayLogger)); - logger->client_type = client_type; - logger->log_level = log_level; - logger->is_direct = is_direct; - logger->conn = conn; - return logger; -} - -void RayLogger_free(RayLogger *logger) { - free(logger); -} - -void RayLogger_log(RayLogger *logger, - int log_level, - const char *event_type, - const char *message) { - if (log_level < logger->log_level) { - return; - } - if (log_level < RAY_LOG_DEBUG || log_level > RAY_LOG_FATAL) { - return; - } - struct timeval tv; - gettimeofday(&tv, NULL); - std::string timestamp = - std::to_string(tv.tv_sec) + "." + std::to_string(tv.tv_usec); - - /* Find number of bytes that would have been written for formatted_message - * size */ - size_t formatted_message_size = - std::snprintf(nullptr, 0, log_fmt, timestamp.c_str(), "%b", - log_levels[log_level], event_type, message, - timestamp.c_str()) + - 1; - /* Fill out everything except the client ID, which is binary data. */ - char formatted_message[formatted_message_size]; - std::snprintf(formatted_message, formatted_message_size, log_fmt, - timestamp.c_str(), "%b", log_levels[log_level], event_type, - message, timestamp.c_str()); - - if (logger->is_direct) { - DBHandle *db = (DBHandle *) logger->conn; - /* Fill in the client ID and send the message to Redis. */ - - redisAsyncContext *context = get_redis_context(db, db->client); - - int status = - redisAsyncCommand(context, NULL, NULL, formatted_message, - (char *) db->client.data(), sizeof(db->client)); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error while logging message to log table"); - } - } else { - /* If we don't own a Redis connection, we leave our client - * ID to be filled in by someone else. */ - int *socket_fd = (int *) logger->conn; - write_log_message(*socket_fd, formatted_message); - } -} - -void RayLogger_log_event(DBHandle *db, - uint8_t *key, - int64_t key_length, - uint8_t *value, - int64_t value_length, - double timestamp) { - std::string timestamp_string = std::to_string(timestamp); - int status = redisAsyncCommand(db->context, NULL, NULL, "ZADD %b %s %b", key, - key_length, timestamp_string.c_str(), value, - value_length); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error while logging message to event log"); - } -} diff --git a/src/common/logging.h b/src/common/logging.h deleted file mode 100644 index 1fa57a60c7123..0000000000000 --- a/src/common/logging.h +++ /dev/null @@ -1,58 +0,0 @@ -#ifndef LOGGING_H -#define LOGGING_H - -#define RAY_LOG_VERBOSE -1 -#define RAY_LOG_DEBUG 0 -#define RAY_LOG_INFO 1 -#define RAY_LOG_WARNING 2 -#define RAY_LOG_ERROR 3 -#define RAY_LOG_FATAL 4 - -/* Entity types. */ -#define RAY_FUNCTION "FUNCTION" -#define RAY_OBJECT "OBJECT" -#define RAY_TASK "TASK" - -#include "state/db.h" - -typedef struct RayLoggerImpl RayLogger; - -/* Initialize a Ray logger for the given client type and logging level. If the - * is_direct flag is set, the logger will treat the given connection as a - * direct connection to the log. Otherwise, it will treat it as a socket to - * another process with a connection to the log. - * NOTE: User is responsible for freeing the returned logger. */ -RayLogger *RayLogger_init(const char *client_type, - int log_level, - int is_direct, - void *conn); - -/* Free the logger. This does not free the connection to the log. 
*/ -void RayLogger_free(RayLogger *logger); - -/* Log an event at the given log level with the given event_type. - * NOTE: message cannot contain spaces! JSON format is recommended. - * TODO: Support spaces in messages. */ -void RayLogger_log(RayLogger *logger, - int log_level, - const char *event_type, - const char *message); - -/** - * Log an event to the event log. - * - * @param db The database handle. - * @param key The key in Redis to store the event in. - * @param key_length The length of the key. - * @param value The value to log. - * @param value_length The length of the value. - * @return Void. - */ -void RayLogger_log_event(DBHandle *db, - uint8_t *key, - int64_t key_length, - uint8_t *value, - int64_t value_length, - double time); - -#endif /* LOGGING_H */ diff --git a/src/common/net.cc b/src/common/net.cc deleted file mode 100644 index 3f2aaf6fa94e5..0000000000000 --- a/src/common/net.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "net.h" - -#include - -#include - -#include "common.h" - -int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port) { - char port_str[6]; - int parsed = sscanf(ip_addr_port, "%15[0-9.]:%5[0-9]", ip_addr, port_str); - if (parsed != 2) { - return -1; - } - *port = atoi(port_str); - return 0; -} - -/* Return true if the ip address is valid. */ -bool valid_ip_address(const std::string &ip_address) { - struct sockaddr_in sa; - int result = inet_pton(AF_INET, ip_address.c_str(), &sa.sin_addr); - return result == 1; -} diff --git a/src/common/net.h b/src/common/net.h deleted file mode 100644 index 109cdf3fa1f33..0000000000000 --- a/src/common/net.h +++ /dev/null @@ -1,9 +0,0 @@ -#ifndef NET_H -#define NET_H - -/* Helper function to parse a string of the form : into the - * given ip_addr and port pointers. The ip_addr buffer must already be - * allocated. Return 0 upon success and -1 upon failure. */ -int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port); - -#endif /* NET_H */ diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc deleted file mode 100644 index d594f74effef0..0000000000000 --- a/src/common/redis_module/ray_redis_module.cc +++ /dev/null @@ -1,1886 +0,0 @@ -#include - -#include "common_protocol.h" -#include "format/common_generated.h" -#include "ray/gcs/format/gcs_generated.h" -#include "ray/id.h" -#include "redis_string.h" -#include "redismodule.h" -#include "task.h" - -#if RAY_USE_NEW_GCS -// Under this flag, ray-project/credis will be loaded. Specifically, via -// "path/redis-server --loadmodule --loadmodule " (dlopen() under the hood) will a definition of "module" -// be supplied. -// -// All commands in this file that depend on "module" must be wrapped by "#if -// RAY_USE_NEW_GCS", until we switch to this launch configuration as the -// default. -#include "chain_module.h" -extern RedisChainModule module; -#endif - -// Various tables are maintained in redis: -// -// == OBJECT TABLE == -// -// This consists of two parts: -// - The object location table, indexed by OL:object_id, which is the set of -// plasma manager indices that have access to the object. -// (In redis this is represented by a zset (sorted set).) -// -// - The object info table, indexed by OI:object_id, which is a hashmap of: -// "hash" -> the hash of the object, -// "data_size" -> the size of the object in bytes, -// "task" -> the task ID that generated this object. -// "is_put" -> 0 or 1. 
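// (Illustratively, a single object is therefore stored under two keys: an
// OL:<object id> zset holding the plasma manager client IDs that have the
// object, and an OI:<object id> hash holding the "hash", "data_size", "task",
// and "is_put" fields described above.)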
-// -// == TASK TABLE == -// -// It maps each TT:task_id to a hash: -// "state" -> the state of the task, encoded as a bit mask of scheduling_state -// enum values in task.h, -// "local_scheduler_id" -> the ID of the local scheduler the task is assigned -// to, -// "TaskSpec" -> serialized bytes of a TaskInfo (defined in common.fbs), which -// describes the details this task. -// -// See also the definition of TaskReply in common.fbs. - -#define OBJECT_INFO_PREFIX "OI:" -#define OBJECT_LOCATION_PREFIX "OL:" -#define OBJECT_NOTIFICATION_PREFIX "ON:" -#define TASK_PREFIX "TT:" -#define OBJECT_BCAST "BCAST" - -#define OBJECT_CHANNEL_PREFIX "OC:" - -#define CHECK_ERROR(STATUS, MESSAGE) \ - if ((STATUS) == REDISMODULE_ERR) { \ - return RedisModule_ReplyWithError(ctx, (MESSAGE)); \ - } - -/// Parse a Redis string into a TablePubsub channel. -TablePubsub ParseTablePubsub(const RedisModuleString *pubsub_channel_str) { - long long pubsub_channel_long; - RAY_CHECK(RedisModule_StringToLongLong( - pubsub_channel_str, &pubsub_channel_long) == REDISMODULE_OK) - << "Pubsub channel must be a valid TablePubsub"; - auto pubsub_channel = static_cast(pubsub_channel_long); - RAY_CHECK(pubsub_channel >= TablePubsub::MIN && - pubsub_channel <= TablePubsub::MAX) - << "Pubsub channel must be a valid TablePubsub"; - return pubsub_channel; -} - -/// Format a pubsub channel for a specific key. pubsub_channel_str should -/// contain a valid TablePubsub. -RedisModuleString *FormatPubsubChannel( - RedisModuleCtx *ctx, - const RedisModuleString *pubsub_channel_str, - const RedisModuleString *id) { - // Format the pubsub channel enum to a string. TablePubsub_MAX should be more - // than enough digits, but add 1 just in case for the null terminator. - char pubsub_channel[static_cast(TablePubsub::MAX) + 1]; - sprintf(pubsub_channel, "%d", - static_cast(ParseTablePubsub(pubsub_channel_str))); - return RedisString_Format(ctx, "%s:%S", pubsub_channel, id); -} - -// TODO(swang): This helper function should be deprecated by the version below, -// which uses enums for table prefixes. -RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx, - const char *prefix, - RedisModuleString *keyname, - int mode, - RedisModuleString **mutated_key_str) { - RedisModuleString *prefixed_keyname = - RedisString_Format(ctx, "%s%S", prefix, keyname); - // Pass out the key being mutated, should the caller request so. 
- if (mutated_key_str != nullptr) { - *mutated_key_str = prefixed_keyname; - } - RedisModuleKey *key = - (RedisModuleKey *) RedisModule_OpenKey(ctx, prefixed_keyname, mode); - return key; -} - -RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx, - RedisModuleString *prefix_enum, - RedisModuleString *keyname, - int mode, - RedisModuleString **mutated_key_str) { - long long prefix_long; - RAY_CHECK(RedisModule_StringToLongLong(prefix_enum, &prefix_long) == - REDISMODULE_OK) - << "Prefix must be a valid TablePrefix"; - auto prefix = static_cast(prefix_long); - RAY_CHECK(prefix != TablePrefix::UNUSED) - << "This table has no prefix registered"; - RAY_CHECK(prefix >= TablePrefix::MIN && prefix <= TablePrefix::MAX) - << "Prefix must be a valid TablePrefix"; - return OpenPrefixedKey(ctx, EnumNameTablePrefix(prefix), keyname, mode, - mutated_key_str); -} - -RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx, - const char *prefix, - RedisModuleString *keyname, - int mode) { - return OpenPrefixedKey(ctx, prefix, keyname, mode, - /*mutated_key_str=*/nullptr); -} - -RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx, - RedisModuleString *prefix_enum, - RedisModuleString *keyname, - int mode) { - return OpenPrefixedKey(ctx, prefix_enum, keyname, mode, - /*mutated_key_str=*/nullptr); -} - -/// Open the key used to store the channels that should be published to when an -/// update happens at the given keyname. -RedisModuleKey *OpenBroadcastKey(RedisModuleCtx *ctx, - RedisModuleString *pubsub_channel_str, - RedisModuleString *keyname, - int mode) { - RedisModuleString *channel = - FormatPubsubChannel(ctx, pubsub_channel_str, keyname); - RedisModuleString *prefixed_keyname = - RedisString_Format(ctx, "BCAST:%S", channel); - RedisModuleKey *key = - (RedisModuleKey *) RedisModule_OpenKey(ctx, prefixed_keyname, mode); - return key; -} - -/** - * This is a helper method to convert a redis module string to a flatbuffer - * string. - * - * @param fbb The flatbuffer builder. - * @param redis_string The redis string. - * @return The flatbuffer string. - */ -flatbuffers::Offset RedisStringToFlatbuf( - flatbuffers::FlatBufferBuilder &fbb, - RedisModuleString *redis_string) { - size_t redis_string_size; - const char *redis_string_str = - RedisModule_StringPtrLen(redis_string, &redis_string_size); - return fbb.CreateString(redis_string_str, redis_string_size); -} - -/** - * Publish a notification to a client's notification channel about an insertion - * or deletion to the db client table. - * - * TODO(swang): Use flatbuffers for the notification message. - * The format for the published notification is: - * : - * If no manager address is provided, manager_address will be set to ":". If - * is_insertion is true, then the last field will be "1", else "0". - * - * @param ctx The Redis context. - * @param ray_client_id The ID of the database client that was inserted or - * deleted. - * @param client_type The type of client that was inserted or deleted. - * @param manager_address An optional secondary address for the object manager - * associated with this database client. - * @param is_insertion A boolean that's true if the update was an insertion and - * false if deletion. - * @return True if the publish was successful and false otherwise. - */ -bool PublishDBClientNotification(RedisModuleCtx *ctx, - RedisModuleString *ray_client_id, - RedisModuleString *client_type, - RedisModuleString *manager_address, - bool is_insertion) { - /* Construct strings to publish on the db client channel. 
*/ - RedisModuleString *channel_name = - RedisModule_CreateString(ctx, "db_clients", strlen("db_clients")); - /* Construct the flatbuffers object to publish over the channel. */ - flatbuffers::FlatBufferBuilder fbb; - /* Use an empty aux address if one is not passed in. */ - flatbuffers::Offset manager_address_str; - if (manager_address != NULL) { - manager_address_str = RedisStringToFlatbuf(fbb, manager_address); - } else { - manager_address_str = fbb.CreateString("", strlen("")); - } - /* Create the flatbuffers message. */ - auto message = CreateSubscribeToDBClientTableReply( - fbb, RedisStringToFlatbuf(fbb, ray_client_id), - RedisStringToFlatbuf(fbb, client_type), manager_address_str, - is_insertion); - fbb.Finish(message); - /* Create a Redis string to publish by serializing the flatbuffers object. */ - RedisModuleString *client_info = RedisModule_CreateString( - ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize()); - - /* Publish the client info on the db client channel. */ - RedisModuleCallReply *reply; - reply = RedisModule_Call(ctx, "PUBLISH", "ss", channel_name, client_info); - return (reply != NULL); -} - -/** - * Register a client with Redis. This is called from a client with the command: - * - * RAY.CONNECT - * ... - * - * The command can take an arbitrary number of pairs of field names and keys, - * and these will be stored in a hashmap associated with this client. Several - * fields are singled out for special treatment: - * - * manager_address: This is provided by local schedulers and plasma - * managers and should be the address of the plasma manager that the - * client is associated with. This is published to the "db_clients" - * channel by the RAY.CONNECT command. - * - * @param ray_client_id The db client ID of the client. - * @param node_ip_address The IP address of the node the client is on. - * @param client_type The type of the client (e.g., plasma_manager). - * @return OK if the operation was successful. - */ -int Connect_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc < 4) { - return RedisModule_WrongArity(ctx); - } - if (argc % 2 != 0) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *ray_client_id = argv[1]; - RedisModuleString *node_ip_address = argv[2]; - RedisModuleString *client_type = argv[3]; - - /* Add this client to the Ray db client table. */ - RedisModuleKey *db_client_table_key = - OpenPrefixedKey(ctx, DB_CLIENT_PREFIX, ray_client_id, REDISMODULE_WRITE); - - if (RedisModule_KeyType(db_client_table_key) != REDISMODULE_KEYTYPE_EMPTY) { - return RedisModule_ReplyWithError(ctx, "Client already exists"); - } - - /* This will be used to construct a publish message. */ - RedisModuleString *manager_address = NULL; - RedisModuleString *manager_address_key = RedisModule_CreateString( - ctx, "manager_address", strlen("manager_address")); - RedisModuleString *deleted = RedisModule_CreateString(ctx, "0", strlen("0")); - - RedisModule_HashSet(db_client_table_key, REDISMODULE_HASH_CFIELDS, - "ray_client_id", ray_client_id, "node_ip_address", - node_ip_address, "client_type", client_type, "deleted", - deleted, NULL); - - for (int i = 4; i < argc; i += 2) { - RedisModuleString *key = argv[i]; - RedisModuleString *value = argv[i + 1]; - RedisModule_HashSet(db_client_table_key, REDISMODULE_HASH_NONE, key, value, - NULL); - if (RedisModule_StringCompare(key, manager_address_key) == 0) { - manager_address = value; - } - } - /* Clean up. 
*/ - if (!PublishDBClientNotification(ctx, ray_client_id, client_type, - manager_address, true)) { - return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); - } - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - return REDISMODULE_OK; -} - -/** - * Remove a client from Redis. This is called from a client with the command: - * - * RAY.DISCONNECT - * - * This method also publishes a notification to all subscribers to the - * db_clients channel. The notification consists of a message of the form ":". - * - * @param ray_client_id The db client ID of the client. - * @return OK if the operation was successful. - */ -int Disconnect_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 2) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *ray_client_id = argv[1]; - - /* Get the client type. */ - RedisModuleKey *db_client_table_key = - OpenPrefixedKey(ctx, DB_CLIENT_PREFIX, ray_client_id, REDISMODULE_WRITE); - - RedisModuleString *deleted_string; - RedisModule_HashGet(db_client_table_key, REDISMODULE_HASH_CFIELDS, "deleted", - &deleted_string, NULL); - long long deleted; - int parsed = RedisModule_StringToLongLong(deleted_string, &deleted); - if (parsed != REDISMODULE_OK) { - return RedisModule_ReplyWithError(ctx, "Unable to parse deleted field"); - } - - bool published = true; - if (deleted == 0) { - /* Remove the client from the client table. */ - RedisModuleString *deleted = - RedisModule_CreateString(ctx, "1", strlen("1")); - RedisModule_HashSet(db_client_table_key, REDISMODULE_HASH_CFIELDS, - "deleted", deleted, NULL); - - RedisModuleString *client_type; - RedisModuleString *manager_address; - RedisModule_HashGet(db_client_table_key, REDISMODULE_HASH_CFIELDS, - "client_type", &client_type, "manager_address", - &manager_address, NULL); - - /* Publish the deletion notification on the db client channel. */ - published = PublishDBClientNotification(ctx, ray_client_id, client_type, - manager_address, false); - } - - if (!published) { - /* Return an error message if we weren't able to publish the deletion - * notification. */ - return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); - } - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - return REDISMODULE_OK; -} - -/** - * Lookup an entry in the object table. - * - * This is called from a client with the command: - * - * RAY.OBJECT_TABLE_LOOKUP - * - * @param object_id A string representing the object ID. - * @return A list, possibly empty, of plasma manager IDs that are listed in the - * object table as having the object. If there was no entry found in - * the object table, returns nil. - */ -int ObjectTableLookup_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 2) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleKey *key = - OpenPrefixedKey(ctx, OBJECT_LOCATION_PREFIX, argv[1], REDISMODULE_READ); - - if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { - /* Return nil if no entry was found. */ - return RedisModule_ReplyWithNull(ctx); - } - if (RedisModule_ValueLength(key) == 0) { - /* Return empty list if there are no managers. 
*/ - return RedisModule_ReplyWithArray(ctx, 0); - } - - CHECK_ERROR( - RedisModule_ZsetFirstInScoreRange(key, REDISMODULE_NEGATIVE_INFINITE, - REDISMODULE_POSITIVE_INFINITE, 1, 1), - "Unable to initialize zset iterator"); - - RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_ARRAY_LEN); - int num_results = 0; - do { - RedisModuleString *curr = RedisModule_ZsetRangeCurrentElement(key, NULL); - RedisModule_ReplyWithString(ctx, curr); - num_results += 1; - } while (RedisModule_ZsetRangeNext(key)); - RedisModule_ReplySetArrayLength(ctx, num_results); - - return REDISMODULE_OK; -} - -/** - * Publish a notification to a client's object notification channel if at least - * one manager is listed as having the object in the object table. - * - * @param ctx The Redis context. - * @param client_id The ID of the client that is being notified. - * @param object_id The object ID of interest. - * @param key The opened key for the entry in the object table corresponding to - * the object ID of interest. - * @return True if the publish was successful and false otherwise. - */ -bool PublishObjectNotification(RedisModuleCtx *ctx, - RedisModuleString *client_id, - RedisModuleString *object_id, - RedisModuleString *data_size, - RedisModuleKey *key) { - flatbuffers::FlatBufferBuilder fbb; - - long long data_size_value; - if (RedisModule_StringToLongLong(data_size, &data_size_value) != - REDISMODULE_OK) { - return RedisModule_ReplyWithError(ctx, "data_size must be integer"); - } - - std::vector> manager_ids; - CHECK_ERROR( - RedisModule_ZsetFirstInScoreRange(key, REDISMODULE_NEGATIVE_INFINITE, - REDISMODULE_POSITIVE_INFINITE, 1, 1), - "Unable to initialize zset iterator"); - /* Loop over the managers in the object table for this object ID. */ - do { - RedisModuleString *curr = RedisModule_ZsetRangeCurrentElement(key, NULL); - manager_ids.push_back(RedisStringToFlatbuf(fbb, curr)); - } while (RedisModule_ZsetRangeNext(key)); - - auto message = CreateSubscribeToNotificationsReply( - fbb, RedisStringToFlatbuf(fbb, object_id), data_size_value, - fbb.CreateVector(manager_ids)); - fbb.Finish(message); - - /* Publish the notification to the clients notification channel. - * TODO(rkn): These notifications could be batched together. */ - RedisModuleString *channel_name = - RedisString_Format(ctx, "%s%S", OBJECT_CHANNEL_PREFIX, client_id); - - RedisModuleString *payload = RedisModule_CreateString( - ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize()); - - RedisModuleCallReply *reply; - reply = RedisModule_Call(ctx, "PUBLISH", "ss", channel_name, payload); - if (reply == NULL) { - return false; - } - return true; -} - -// NOTE(pcmoritz): This is a temporary redis command that will be removed once -// the GCS uses https://github.com/pcmoritz/credis. -int PublishTaskTableAdd(RedisModuleCtx *ctx, - RedisModuleString *id, - RedisModuleString *data) { - const char *buf = RedisModule_StringPtrLen(data, NULL); - auto message = flatbuffers::GetRoot(buf); - RAY_CHECK(message != nullptr); - - if (message->scheduling_state() == SchedulingState::WAITING || - message->scheduling_state() == SchedulingState::SCHEDULED) { - /* Build the PUBLISH topic and message for task table subscribers. The - * topic - * is a string in the format "TASK_PREFIX::". - * The - * message is a serialized SubscribeToTasksReply flatbuffer object. 
*/ - std::string state = - std::to_string(static_cast(message->scheduling_state())); - RedisModuleString *publish_topic = RedisString_Format( - ctx, "%s%b:%s", TASK_PREFIX, message->scheduler_id()->str().data(), - sizeof(DBClientID), state.c_str()); - - /* Construct the flatbuffers object for the payload. */ - flatbuffers::FlatBufferBuilder fbb; - /* Create the flatbuffers message. */ - auto msg = - CreateTaskReply(fbb, RedisStringToFlatbuf(fbb, id), - static_cast(message->scheduling_state()), - fbb.CreateString(message->scheduler_id()), - fbb.CreateString(message->execution_dependencies()), - fbb.CreateString(message->task_info()), - message->spillback_count(), true /* not used */); - fbb.Finish(msg); - - RedisModuleString *publish_message = RedisModule_CreateString( - ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize()); - - RedisModuleCallReply *reply = - RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message); - - /* See how many clients received this publish. */ - long long num_clients = RedisModule_CallReplyInteger(reply); - RAY_CHECK(num_clients <= 1) << "Published to " << num_clients - << " clients."; - } - return RedisModule_ReplyWithSimpleString(ctx, "OK"); -} - -/// Publish a notification for a new entry at a key. This publishes a -/// notification to all subscribers of the table, as well as every client that -/// has requested notifications for this key. -/// -/// \param pubsub_channel_str The pubsub channel name that notifications for -/// this key should be published to. When publishing to a specific -/// client, the channel name should be :. -/// \param id The ID of the key that the notification is about. -/// \param data The data to publish. -/// \return OK if there is no error during a publish. -int PublishTableAdd(RedisModuleCtx *ctx, - RedisModuleString *pubsub_channel_str, - RedisModuleString *id, - RedisModuleString *data) { - // Serialize the notification to send. - flatbuffers::FlatBufferBuilder fbb; - auto data_flatbuf = RedisStringToFlatbuf(fbb, data); - auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, id), - fbb.CreateVector(&data_flatbuf, 1)); - fbb.Finish(message); - - // Write the data back to any subscribers that are listening to all table - // notifications. - RedisModuleCallReply *reply = - RedisModule_Call(ctx, "PUBLISH", "sb", pubsub_channel_str, - fbb.GetBufferPointer(), fbb.GetSize()); - if (reply == NULL) { - return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); - } - - // Publish the data to any clients who requested notifications on this key. - RedisModuleKey *notification_key = OpenBroadcastKey( - ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE); - if (RedisModule_KeyType(notification_key) != REDISMODULE_KEYTYPE_EMPTY) { - // NOTE(swang): Sets are not implemented yet, so we use ZSETs instead. 
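// Each member of this broadcast ZSET is a per-client channel name of the form
// <pubsub channel>:<client ID> (see FormatPubsubChannel above); clients are
// added to it by RAY.TABLE_REQUEST_NOTIFICATIONS, and the same serialized
// GcsTableEntry is published to every such channel in the loop below.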
- CHECK_ERROR(RedisModule_ZsetFirstInScoreRange( - notification_key, REDISMODULE_NEGATIVE_INFINITE, - REDISMODULE_POSITIVE_INFINITE, 1, 1), - "Unable to initialize zset iterator"); - for (; !RedisModule_ZsetRangeEndReached(notification_key); - RedisModule_ZsetRangeNext(notification_key)) { - RedisModuleString *client_channel = - RedisModule_ZsetRangeCurrentElement(notification_key, NULL); - RedisModuleCallReply *reply = - RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, - fbb.GetBufferPointer(), fbb.GetSize()); - if (reply == NULL) { - return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); - } - } - } - return RedisModule_ReplyWithSimpleString(ctx, "OK"); -} - -// RAY.TABLE_ADD: -// TableAdd_RedisCommand: the actual command handler. -// (helper) TableAdd_DoWrite: performs the write to redis state. -// (helper) TableAdd_DoPublish: performs a publish after the write. -// ChainTableAdd_RedisCommand: the same command, chain-enabled. - -int TableAdd_DoWrite(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc, - RedisModuleString **mutated_key_str) { - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *id = argv[3]; - RedisModuleString *data = argv[4]; - - RedisModuleKey *key = - OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE, - mutated_key_str); - RedisModule_StringSet(key, data); - return REDISMODULE_OK; -} - -int TableAdd_DoPublish(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - RedisModuleString *pubsub_channel_str = argv[2]; - RedisModuleString *id = argv[3]; - RedisModuleString *data = argv[4]; - - TablePubsub pubsub_channel = ParseTablePubsub(pubsub_channel_str); - - if (pubsub_channel == TablePubsub::TASK) { - // Publish the task to its subscribers. - // TODO(swang): This is only necessary for legacy Ray and should be removed - // once we switch to using the new GCS API for the task table. - return PublishTaskTableAdd(ctx, id, data); - } else if (pubsub_channel != TablePubsub::NO_PUBLISH) { - // All other pubsub channels write the data back directly onto the channel. - return PublishTableAdd(ctx, pubsub_channel_str, id, data); - } else { - return RedisModule_ReplyWithSimpleString(ctx, "OK"); - } -} - -/// Add an entry at a key. This overwrites any existing data at the key. -/// Publishes a notification about the update to all subscribers, if a pubsub -/// channel is provided. -/// -/// This is called from a client with the command: -/// -/// RAY.TABLE_ADD -/// -/// \param table_prefix The prefix string for keys in this table. -/// \param pubsub_channel The pubsub channel name that notifications for -/// this key should be published to. When publishing to a specific -/// client, the channel name should be :. -/// \param id The ID of the key to set. -/// \param data The data to insert at the key. -/// \return The current value at the key, or OK if there is no value. 
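// --- Illustrative sketch (not part of the original module) ------------------
// Shows how a client might issue the RAY.TABLE_ADD command described above
// over a plain Redis connection using the hiredis API. The numeric prefix and
// channel values, the ID, and the payload are hypothetical placeholders; real
// callers pass valid TablePrefix/TablePubsub enum values and a serialized
// flatbuffer as the data argument.
#include <hiredis/hiredis.h>

void ExampleTableAdd(redisContext *context, const char *id, size_t id_len,
                     const char *data, size_t data_len) {
  // Argument layout expected by TableAdd_RedisCommand below:
  //   RAY.TABLE_ADD <table_prefix> <pubsub_channel> <id> <data>
  redisReply *reply = static_cast<redisReply *>(
      redisCommand(context, "RAY.TABLE_ADD %d %d %b %b",
                   2 /* hypothetical TablePrefix value */,
                   0 /* hypothetical TablePubsub value */, id, id_len, data,
                   data_len));
  if (reply == nullptr) {
    return;  // Connection-level error.
  }
  // On success the module replies with the simple status string "OK".
  freeReplyObject(reply);
}
// -----------------------------------------------------------------------------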
-int TableAdd_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - TableAdd_DoWrite(ctx, argv, argc, /*mutated_key_str=*/nullptr); - return TableAdd_DoPublish(ctx, argv, argc); -} - -#if RAY_USE_NEW_GCS -int ChainTableAdd_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - return module.ChainReplicate(ctx, argv, argc, /*node_func=*/TableAdd_DoWrite, - /*tail_func=*/TableAdd_DoPublish); -} -#endif - -int TableAppend_DoWrite(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc, - RedisModuleString **mutated_key_str) { - if (argc < 5 || argc > 6) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *id = argv[3]; - RedisModuleString *data = argv[4]; - RedisModuleString *index_str = nullptr; - if (argc == 6) { - index_str = argv[5]; - } - - // Set the keys in the table. - RedisModuleKey *key = - OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE, - mutated_key_str); - // Determine the index at which the data should be appended. If no index is - // requested, then is the current length of the log. - size_t index = RedisModule_ValueLength(key); - if (index_str != nullptr) { - // Parse the requested index. - long long requested_index; - RAY_CHECK(RedisModule_StringToLongLong(index_str, &requested_index) == - REDISMODULE_OK); - RAY_CHECK(requested_index >= 0); - index = static_cast(requested_index); - } - // Only perform the append if the requested index matches the current length - // of the log, or if no index was requested. - if (index == RedisModule_ValueLength(key)) { - // The requested index matches the current length of the log or no index - // was requested. Perform the append. - int flags = REDISMODULE_ZADD_NX; - RedisModule_ZsetAdd(key, index, data, &flags); - // Check that we actually add a new entry during the append. This is only - // necessary since we implement the log with a sorted set, so all entries - // must be unique, or else we will have gaps in the log. - // TODO(rkn): We need to get rid of this uniqueness requirement. We can - // easily have multiple log events with the same message. - RAY_CHECK(flags == REDISMODULE_ZADD_ADDED) << "Appended a duplicate entry"; - return REDISMODULE_OK; - } else { - // The requested index did not match the current length of the log. Return - // an error message as a string. - static const char *reply = "ERR entry exists"; - RedisModule_ReplyWithStringBuffer(ctx, reply, strlen(reply)); - return REDISMODULE_ERR; - } -} - -int TableAppend_DoPublish(RedisModuleCtx *ctx, - RedisModuleString **argv, - int /*argc*/) { - RedisModuleString *pubsub_channel_str = argv[2]; - RedisModuleString *id = argv[3]; - RedisModuleString *data = argv[4]; - // Publish a message on the requested pubsub channel if necessary. - TablePubsub pubsub_channel = ParseTablePubsub(pubsub_channel_str); - if (pubsub_channel != TablePubsub::NO_PUBLISH) { - // All other pubsub channels write the data back directly onto the - // channel. - return PublishTableAdd(ctx, pubsub_channel_str, id, data); - } else { - return RedisModule_ReplyWithSimpleString(ctx, "OK"); - } -} - -/// Append an entry to the log stored at a key. Publishes a notification about -/// the update to all subscribers, if a pubsub channel is provided. -/// -/// This is called from a client with the command: -// -/// RAY.TABLE_APPEND -/// -/// -/// \param table_prefix The prefix string for keys in this table. 
-/// \param pubsub_channel The pubsub channel name that notifications for -/// this key should be published to. When publishing to a specific -/// client, the channel name should be :. -/// \param id The ID of the key to append to. -/// \param data The data to append to the key. -/// \param index If this is set, then the data must be appended at this index. -/// If the current log is shorter or longer than the requested index, -/// then the append will fail and an error message will be returned as a -/// string. -/// \return OK if the append succeeds, or an error message string if the append -/// fails. -int TableAppend_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - const int status = TableAppend_DoWrite(ctx, argv, argc, - /*mutated_key_str=*/nullptr); - if (status) { - return status; - } - return TableAppend_DoPublish(ctx, argv, argc); -} - -#if RAY_USE_NEW_GCS -int ChainTableAppend_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - return module.ChainReplicate(ctx, argv, argc, - /*node_func=*/TableAppend_DoWrite, - /*tail_func=*/TableAppend_DoPublish); -} -#endif - -/// A helper function to create and finish a GcsTableEntry, based on the -/// current value or values at the given key. -void TableEntryToFlatbuf(RedisModuleKey *table_key, - RedisModuleString *entry_id, - flatbuffers::FlatBufferBuilder &fbb) { - auto key_type = RedisModule_KeyType(table_key); - switch (key_type) { - case REDISMODULE_KEYTYPE_STRING: { - // Build the flatbuffer from the string data. - size_t data_len = 0; - char *data_buf = - RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); - auto data = fbb.CreateString(data_buf, data_len); - auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(&data, 1)); - fbb.Finish(message); - } break; - case REDISMODULE_KEYTYPE_ZSET: { - // Build the flatbuffer from the set of log entries. - RAY_CHECK(RedisModule_ZsetFirstInScoreRange( - table_key, REDISMODULE_NEGATIVE_INFINITE, - REDISMODULE_POSITIVE_INFINITE, 1, 1) == REDISMODULE_OK); - std::vector> data; - for (; !RedisModule_ZsetRangeEndReached(table_key); - RedisModule_ZsetRangeNext(table_key)) { - data.push_back(RedisStringToFlatbuf( - fbb, RedisModule_ZsetRangeCurrentElement(table_key, NULL))); - } - auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(data)); - fbb.Finish(message); - } break; - case REDISMODULE_KEYTYPE_EMPTY: { - auto message = CreateGcsTableEntry( - fbb, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector( - std::vector>())); - fbb.Finish(message); - } break; - default: - RAY_LOG(FATAL) << "Invalid Redis type during lookup: " << key_type; - } -} - -/// Lookup the current value or values at a key. Returns the current value or -/// values at the key. -/// -/// This is called from a client with the command: -// -/// RAY.TABLE_LOOKUP -/// -/// \param table_prefix The prefix string for keys in this table. -/// \param pubsub_channel The pubsub channel name that notifications for -/// this key should be published to. This field is unused for lookups. -/// \param id The ID of the key to lookup. -/// \return nil if the key is empty, the current value if the key type is a -/// string, or an array of the current values if the key type is a set. 
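// --- Illustrative sketch (not part of the original module) ------------------
// The matching read path: a lookup issued with hiredis. The handler below only
// parses the table prefix and the ID, so the pubsub channel argument is unused
// for lookups but must still be supplied; all concrete values here are
// hypothetical placeholders.
#include <hiredis/hiredis.h>

void ExampleTableLookup(redisContext *context, const char *id, size_t id_len) {
  redisReply *reply = static_cast<redisReply *>(
      redisCommand(context, "RAY.TABLE_LOOKUP %d %d %b",
                   2 /* hypothetical TablePrefix value */,
                   0 /* pubsub channel, unused for lookups */, id, id_len));
  if (reply == nullptr) {
    return;  // Connection-level error.
  }
  if (reply->type == REDIS_REPLY_NIL) {
    // No entry is stored at this key.
  } else if (reply->type == REDIS_REPLY_STRING) {
    // reply->str / reply->len hold a serialized GcsTableEntry whose data
    // vector contains the single string value or every entry of the log.
  }
  freeReplyObject(reply);
}
// -----------------------------------------------------------------------------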
-int TableLookup_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc < 4) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *id = argv[3]; - - // Lookup the data at the key. - RedisModuleKey *table_key = - OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ); - if (table_key == nullptr) { - RedisModule_ReplyWithNull(ctx); - } else { - // Serialize the data to a flatbuffer to return to the client. - flatbuffers::FlatBufferBuilder fbb; - TableEntryToFlatbuf(table_key, id, fbb); - RedisModule_ReplyWithStringBuffer( - ctx, reinterpret_cast(fbb.GetBufferPointer()), - fbb.GetSize()); - } - return REDISMODULE_OK; -} - -/// Request notifications for changes to a key. Returns the current value or -/// values at the key. Notifications will be sent to the requesting client for -/// every subsequent TABLE_ADD to the key. -/// -/// This is called from a client with the command: -// -/// RAY.TABLE_REQUEST_NOTIFICATIONS -/// -/// -/// \param table_prefix The prefix string for keys in this table. -/// \param pubsub_channel The pubsub channel name that notifications for -/// this key should be published to. When publishing to a specific -/// client, the channel name should be :. -/// \param id The ID of the key to publish notifications for. -/// \param client_id The ID of the client that is being notified. -/// \return nil if the key is empty, the current value if the key type is a -/// string, or an array of the current values if the key type is a set. -int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *pubsub_channel_str = argv[2]; - RedisModuleString *id = argv[3]; - RedisModuleString *client_id = argv[4]; - RedisModuleString *client_channel = - FormatPubsubChannel(ctx, pubsub_channel_str, client_id); - - // Add this client to the set of clients that should be notified when there - // are changes to the key. - RedisModuleKey *notification_key = OpenBroadcastKey( - ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE); - CHECK_ERROR(RedisModule_ZsetAdd(notification_key, 0.0, client_channel, NULL), - "ZsetAdd failed."); - - // Lookup the current value at the key. - RedisModuleKey *table_key = - OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ); - // Publish the current value at the key to the client that is requesting - // notifications. An empty notification will be published if the key is - // empty. - flatbuffers::FlatBufferBuilder fbb; - TableEntryToFlatbuf(table_key, id, fbb); - RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, - reinterpret_cast(fbb.GetBufferPointer()), - fbb.GetSize()); - - return RedisModule_ReplyWithNull(ctx); -} - -/// Cancel notifications for changes to a key. The client will no longer -/// receive notifications for this key. This does not check if the client -/// first requested notifications before canceling them. -/// -/// This is called from a client with the command: -// -/// RAY.TABLE_CANCEL_NOTIFICATIONS -/// -/// -/// \param table_prefix The prefix string for keys in this table. -/// \param pubsub_channel The pubsub channel name that notifications for -/// this key should be published to. If publishing to a specific client, -/// then the channel name should be :. 
-/// \param id The ID of the key to publish notifications for. -/// \param client_id The ID of the client to cancel notifications for. -/// \return OK. -int TableCancelNotifications_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc < 5) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *pubsub_channel_str = argv[2]; - RedisModuleString *id = argv[3]; - RedisModuleString *client_id = argv[4]; - RedisModuleString *client_channel = - FormatPubsubChannel(ctx, pubsub_channel_str, client_id); - - // Remove this client from the set of clients that should be notified when - // there are changes to the key. - RedisModuleKey *notification_key = OpenBroadcastKey( - ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE); - if (RedisModule_KeyType(notification_key) != REDISMODULE_KEYTYPE_EMPTY) { - RAY_CHECK(RedisModule_ZsetRem(notification_key, client_channel, NULL) == - REDISMODULE_OK); - } - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - return REDISMODULE_OK; -} - -bool is_nil(const std::string &data) { - RAY_CHECK(data.size() == kUniqueIDSize); - const uint8_t *d = reinterpret_cast(data.data()); - for (int i = 0; i < kUniqueIDSize; ++i) { - if (d[i] != 255) { - return false; - } - } - return true; -} - -// This is a temporary redis command that will be removed once -// the GCS uses https://github.com/pcmoritz/credis. -// Be careful, this only supports Task Table payloads. -int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *id = argv[3]; - RedisModuleString *update_data = argv[4]; - - RedisModuleKey *key = OpenPrefixedKey(ctx, prefix_str, id, - REDISMODULE_READ | REDISMODULE_WRITE); - - size_t value_len = 0; - char *value_buf = RedisModule_StringDMA(key, &value_len, REDISMODULE_READ); - - size_t update_len = 0; - const char *update_buf = RedisModule_StringPtrLen(update_data, &update_len); - - auto data = flatbuffers::GetMutableRoot( - reinterpret_cast(value_buf)); - - auto update = flatbuffers::GetRoot(update_buf); - - bool do_update = static_cast(data->scheduling_state()) & - static_cast(update->test_state_bitmask()); - - if (!is_nil(update->test_scheduler_id()->str())) { - do_update = - do_update && - update->test_scheduler_id()->str() == data->scheduler_id()->str(); - } - - if (do_update) { - RAY_CHECK(data->mutate_scheduling_state(update->update_state())); - } - RAY_CHECK(data->mutate_updated(do_update)); - - int result = RedisModule_ReplyWithStringBuffer(ctx, value_buf, value_len); - - return result; -} - -/** - * Add a new entry to the object table or update an existing one. - * - * This is called from a client with the command: - * - * RAY.OBJECT_TABLE_ADD - * - * @param object_id A string representing the object ID. - * @param data_size An integer which is the object size in bytes. - * @param hash_string A string which is a hash of the object. - * @param manager A string which represents the manager ID of the plasma manager - * that has the object. - * @return OK if the operation was successful. If the same object_id is already - * present with a different hash value, the entry is still added, but - * an error with string "hash mismatch" is returned. 
- */ -int ObjectTableAdd_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *object_id = argv[1]; - RedisModuleString *data_size = argv[2]; - RedisModuleString *new_hash = argv[3]; - RedisModuleString *manager = argv[4]; - - long long data_size_value; - if (RedisModule_StringToLongLong(data_size, &data_size_value) != - REDISMODULE_OK) { - return RedisModule_ReplyWithError(ctx, "data_size must be integer"); - } - - /* Set the fields in the object info table. */ - RedisModuleKey *key; - key = OpenPrefixedKey(ctx, OBJECT_INFO_PREFIX, object_id, - REDISMODULE_READ | REDISMODULE_WRITE); - - /* Check if this object was already registered and if the hashes agree. */ - bool hash_mismatch = false; - if (RedisModule_KeyType(key) != REDISMODULE_KEYTYPE_EMPTY) { - RedisModuleString *existing_hash; - RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "hash", &existing_hash, - NULL); - /* The existing hash may be NULL even if the key is present because a call - * to RAY.RESULT_TABLE_ADD may have already created the key. */ - if (existing_hash != NULL) { - /* Check whether the new hash value matches the old one. If not, we will - * later return the "hash mismatch" error. */ - hash_mismatch = (RedisModule_StringCompare(existing_hash, new_hash) != 0); - } - } - - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "hash", new_hash, NULL); - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "data_size", data_size, - NULL); - - /* Add the location in the object location table. */ - RedisModuleKey *table_key; - table_key = OpenPrefixedKey(ctx, OBJECT_LOCATION_PREFIX, object_id, - REDISMODULE_READ | REDISMODULE_WRITE); - - /* Sets are not implemented yet, so we use ZSETs instead. */ - RedisModule_ZsetAdd(table_key, 0.0, manager, NULL); - - RedisModuleString *bcast_client_str = - RedisModule_CreateString(ctx, OBJECT_BCAST, strlen(OBJECT_BCAST)); - bool success = PublishObjectNotification(ctx, bcast_client_str, object_id, - data_size, table_key); - if (!success) { - /* The publish failed somehow. */ - return RedisModule_ReplyWithError(ctx, "PUBLISH BCAST unsuccessful"); - } - - /* Get the zset of clients that requested a notification about the - * availability of this object. */ - RedisModuleKey *object_notification_key = - OpenPrefixedKey(ctx, OBJECT_NOTIFICATION_PREFIX, object_id, - REDISMODULE_READ | REDISMODULE_WRITE); - /* If the zset exists, initialize the key to iterate over the zset. */ - if (RedisModule_KeyType(object_notification_key) != - REDISMODULE_KEYTYPE_EMPTY) { - CHECK_ERROR(RedisModule_ZsetFirstInScoreRange( - object_notification_key, REDISMODULE_NEGATIVE_INFINITE, - REDISMODULE_POSITIVE_INFINITE, 1, 1), - "Unable to initialize zset iterator"); - /* Iterate over the list of clients that requested notifiations about the - * availability of this object, and publish notifications to their object - * notification channels. */ - - do { - RedisModuleString *client_id = - RedisModule_ZsetRangeCurrentElement(object_notification_key, NULL); - /* TODO(rkn): Some computation could be saved by batching the string - * constructions in the multiple calls to PublishObjectNotification - * together. */ - bool success = PublishObjectNotification(ctx, client_id, object_id, - data_size, table_key); - if (!success) { - /* The publish failed somehow. 
*/ - return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); - } - } while (RedisModule_ZsetRangeNext(object_notification_key)); - /* Now that the clients have been notified, remove the zset of clients - * waiting for notifications. */ - CHECK_ERROR(RedisModule_DeleteKey(object_notification_key), - "Unable to delete zset key."); - } - - if (hash_mismatch) { - return RedisModule_ReplyWithError(ctx, "hash mismatch"); - } else { - RedisModule_ReplyWithSimpleString(ctx, "OK"); - return REDISMODULE_OK; - } -} - -/** - * Remove a manager from a location entry in the object table. - * - * This is called from a client with the command: - * - * RAY.OBJECT_TABLE_REMOVE - * - * @param object_id A string representing the object ID. - * @param manager A string which represents the manager ID of the plasma manager - * to remove. - * @return OK if the operation was successful or an error with string - * "object not found" if the entry for the object_id doesn't exist. The - * operation is counted as a success if the manager was already not in - * the entry. - */ -int ObjectTableRemove_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 3) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *object_id = argv[1]; - RedisModuleString *manager = argv[2]; - - /* Remove the location from the object location table. */ - RedisModuleKey *table_key; - table_key = OpenPrefixedKey(ctx, OBJECT_LOCATION_PREFIX, object_id, - REDISMODULE_READ | REDISMODULE_WRITE); - if (RedisModule_KeyType(table_key) == REDISMODULE_KEYTYPE_EMPTY) { - return RedisModule_ReplyWithError(ctx, "object not found"); - } - - RedisModule_ZsetRem(table_key, manager, NULL); - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - return REDISMODULE_OK; -} - -/** - * Request notifications about the presence of some object IDs. This command - * takes a list of object IDs. For each object ID, the reply will be the list - * of plasma managers that contain the object. If the list of plasma managers - * is currently nonempty, then the reply will happen immediately. Else, the - * reply will come later, on the first invocation of `RAY.OBJECT_TABLE_ADD` - * following this call. - * - * This is called from a client with the command: - * - * RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS - * ... - * - * @param client_id The ID of the client that is requesting the notifications. - * @param object_id(n) The ID of the nth object ID that is passed to this - * command. This command can take any number of object IDs. - * @return OK if the operation was successful. - */ -int ObjectTableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc < 3) { - return RedisModule_WrongArity(ctx); - } - - /* The first argument is the client ID. The other arguments are object IDs. */ - RedisModuleString *client_id = argv[1]; - - /* Loop over the object ID arguments to this command. */ - for (int i = 2; i < argc; ++i) { - RedisModuleString *object_id = argv[i]; - RedisModuleKey *key = OpenPrefixedKey(ctx, OBJECT_LOCATION_PREFIX, - object_id, REDISMODULE_READ); - if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY || - RedisModule_ValueLength(key) == 0) { - /* This object ID is currently not present, so make a note that this - * client should be notified when this object ID becomes available. 
*/ - RedisModuleKey *object_notification_key = - OpenPrefixedKey(ctx, OBJECT_NOTIFICATION_PREFIX, object_id, - REDISMODULE_READ | REDISMODULE_WRITE); - /* Add this client to the list of clients that will be notified when this - * object becomes available. */ - CHECK_ERROR( - RedisModule_ZsetAdd(object_notification_key, 0.0, client_id, NULL), - "ZsetAdd failed."); - } else { - /* Publish a notification to the client's object notification channel. */ - /* Extract the data_size first. */ - RedisModuleKey *object_info_key; - object_info_key = - OpenPrefixedKey(ctx, OBJECT_INFO_PREFIX, object_id, REDISMODULE_READ); - if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { - return RedisModule_ReplyWithError(ctx, "requested object not found"); - } - RedisModuleString *existing_data_size; - RedisModule_HashGet(object_info_key, REDISMODULE_HASH_CFIELDS, - "data_size", &existing_data_size, NULL); - if (existing_data_size == NULL) { - return RedisModule_ReplyWithError(ctx, - "no data_size field in object info"); - } - - bool success = PublishObjectNotification(ctx, client_id, object_id, - existing_data_size, key); - if (!success) { - /* The publish failed somehow. */ - return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); - } - } - } - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - return REDISMODULE_OK; -} - -int ObjectInfoSubscribe_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - REDISMODULE_NOT_USED(argv); - REDISMODULE_NOT_USED(argc); - return REDISMODULE_OK; -} - -/** - * Add a new entry to the result table or update an existing one. - * - * This is called from a client with the command: - * - * RAY.RESULT_TABLE_ADD - * - * @param object_id A string representing the object ID. - * @param task_id A string representing the task ID of the task that produced - * the object. - * @param is_put An integer that is 1 if the object was created through ray.put - * and 0 if created by return value. - * @return OK if the operation was successful. - */ -int ResultTableAdd_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 4) { - return RedisModule_WrongArity(ctx); - } - - /* Set the task ID under field "task" in the object info table. */ - RedisModuleString *object_id = argv[1]; - RedisModuleString *task_id = argv[2]; - RedisModuleString *is_put = argv[3]; - - /* Check to make sure the is_put field was a 0 or a 1. */ - long long is_put_integer; - if ((RedisModule_StringToLongLong(is_put, &is_put_integer) != - REDISMODULE_OK) || - (is_put_integer != 0 && is_put_integer != 1)) { - return RedisModule_ReplyWithError( - ctx, "The is_put field must be either a 0 or a 1."); - } - - RedisModuleKey *key; - key = OpenPrefixedKey(ctx, OBJECT_INFO_PREFIX, object_id, REDISMODULE_WRITE); - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "task", task_id, "is_put", - is_put, NULL); - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - - return REDISMODULE_OK; -} - -/** - * Reply with information about a task ID. This is used by - * RAY.RESULT_TABLE_LOOKUP and RAY.TASK_TABLE_GET. - * - * @param ctx The Redis context. - * @param task_id The task ID of the task to reply about. - * @param updated A boolean representing whether the task was updated during - * this operation. This field is only used for - * RAY.TASK_TABLE_TEST_AND_UPDATE operations. - * @return NIL if the task ID is not in the task table. 
An error if the task ID - * is in the task table but the appropriate fields are not there, and - * an array of the task scheduling state, the local scheduler ID, and - * the task spec for the task otherwise. - */ -int ReplyWithTask(RedisModuleCtx *ctx, - RedisModuleString *task_id, - bool updated) { - RedisModuleKey *key = - OpenPrefixedKey(ctx, TASK_PREFIX, task_id, REDISMODULE_READ); - - if (RedisModule_KeyType(key) != REDISMODULE_KEYTYPE_EMPTY) { - /* If the key exists, look up the fields and return them in an array. */ - RedisModuleString *state = NULL; - RedisModuleString *local_scheduler_id = NULL; - RedisModuleString *execution_dependencies = NULL; - RedisModuleString *task_spec = NULL; - RedisModuleString *spillback_count = NULL; - RedisModule_HashGet( - key, REDISMODULE_HASH_CFIELDS, "state", &state, "local_scheduler_id", - &local_scheduler_id, "execution_dependencies", &execution_dependencies, - "TaskSpec", &task_spec, "spillback_count", &spillback_count, NULL); - if (state == NULL || local_scheduler_id == NULL || - execution_dependencies == NULL || task_spec == NULL || - spillback_count == NULL) { - /* We must have either all fields or no fields. */ - return RedisModule_ReplyWithError( - ctx, "Missing fields in the task table entry"); - } - - long long state_integer; - long long spillback_count_val; - if ((RedisModule_StringToLongLong(state, &state_integer) != - REDISMODULE_OK) || - (state_integer < 0) || - (RedisModule_StringToLongLong(spillback_count, &spillback_count_val) != - REDISMODULE_OK) || - (spillback_count_val < 0)) { - return RedisModule_ReplyWithError( - ctx, "Found invalid scheduling state or spillback count."); - } - - flatbuffers::FlatBufferBuilder fbb; - auto message = CreateTaskReply( - fbb, RedisStringToFlatbuf(fbb, task_id), state_integer, - RedisStringToFlatbuf(fbb, local_scheduler_id), - RedisStringToFlatbuf(fbb, execution_dependencies), - RedisStringToFlatbuf(fbb, task_spec), spillback_count_val, updated); - fbb.Finish(message); - - RedisModuleString *reply = RedisModule_CreateString( - ctx, (char *) fbb.GetBufferPointer(), fbb.GetSize()); - RedisModule_ReplyWithString(ctx, reply); - } else { - /* If the key does not exist, return nil. */ - RedisModule_ReplyWithNull(ctx); - } - - return REDISMODULE_OK; -} - -/** - * Lookup an entry in the result table. - * - * This is called from a client with the command: - * - * RAY.RESULT_TABLE_LOOKUP - * - * @param object_id A string representing the object ID. - * @return NIL if the object ID is not in the result table. Otherwise, this - * returns a ResultTableReply flatbuffer. - */ -int ResultTableLookup_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 2) { - return RedisModule_WrongArity(ctx); - } - - /* Get the task ID under field "task" in the object info table. */ - RedisModuleString *object_id = argv[1]; - - RedisModuleKey *key; - key = OpenPrefixedKey(ctx, OBJECT_INFO_PREFIX, object_id, REDISMODULE_READ); - - if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { - return RedisModule_ReplyWithNull(ctx); - } - - RedisModuleString *task_id; - RedisModuleString *is_put; - RedisModuleString *data_size; - RedisModuleString *hash; - RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "task", &task_id, "is_put", - &is_put, "data_size", &data_size, "hash", &hash, NULL); - - if (task_id == NULL || is_put == NULL) { - return RedisModule_ReplyWithNull(ctx); - } - - /* Check to make sure the is_put field was a 0 or a 1. 
*/ - long long is_put_integer; - if (RedisModule_StringToLongLong(is_put, &is_put_integer) != REDISMODULE_OK || - (is_put_integer != 0 && is_put_integer != 1)) { - return RedisModule_ReplyWithError( - ctx, "The is_put field must be either a 0 or a 1."); - } - - /* Make and return the flatbuffer reply. */ - flatbuffers::FlatBufferBuilder fbb; - long long data_size_value; - - if (data_size == NULL) { - data_size_value = -1; - } else { - RedisModule_StringToLongLong(data_size, &data_size_value); - RAY_CHECK(RedisModule_StringToLongLong(data_size, &data_size_value) == - REDISMODULE_OK); - } - - flatbuffers::Offset hash_str; - if (hash == NULL) { - hash_str = fbb.CreateString("", strlen("")); - } else { - hash_str = RedisStringToFlatbuf(fbb, hash); - } - - flatbuffers::Offset message = - CreateResultTableReply(fbb, RedisStringToFlatbuf(fbb, task_id), - bool(is_put_integer), data_size_value, hash_str); - - fbb.Finish(message); - RedisModuleString *reply = RedisModule_CreateString( - ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize()); - RedisModule_ReplyWithString(ctx, reply); - - return REDISMODULE_OK; -} - -int TaskTableWrite(RedisModuleCtx *ctx, - RedisModuleString *task_id, - RedisModuleString *state, - RedisModuleString *local_scheduler_id, - RedisModuleString *execution_dependencies, - RedisModuleString *spillback_count, - RedisModuleString *task_spec) { - /* Extract the scheduling state. */ - long long state_value; - if (RedisModule_StringToLongLong(state, &state_value) != REDISMODULE_OK) { - return RedisModule_ReplyWithError(ctx, "scheduling state must be integer"); - } - - long long spillback_count_value; - if (RedisModule_StringToLongLong(spillback_count, &spillback_count_value) != - REDISMODULE_OK) { - return RedisModule_ReplyWithError(ctx, "spillback count must be integer"); - } - /* Add the task to the task table. If no spec was provided, get the existing - * spec out of the task table so we can publish it. */ - RedisModuleString *existing_task_spec = NULL; - RedisModuleKey *key = - OpenPrefixedKey(ctx, TASK_PREFIX, task_id, REDISMODULE_WRITE); - if (task_spec == NULL) { - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, - "local_scheduler_id", local_scheduler_id, - "execution_dependencies", execution_dependencies, - "spillback_count", spillback_count, NULL); - RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "TaskSpec", - &existing_task_spec, NULL); - if (existing_task_spec == NULL) { - return RedisModule_ReplyWithError( - ctx, "Cannot update a task that doesn't exist yet"); - } - } else { - RedisModule_HashSet( - key, REDISMODULE_HASH_CFIELDS, "state", state, "local_scheduler_id", - local_scheduler_id, "execution_dependencies", execution_dependencies, - "TaskSpec", task_spec, "spillback_count", spillback_count, NULL); - } - - if (static_cast(state_value) == TaskStatus::WAITING || - static_cast(state_value) == TaskStatus::SCHEDULED) { - /* Build the PUBLISH topic and message for task table subscribers. The - * topic is a string in the format - * "TASK_PREFIX::". The message is a serialized - * SubscribeToTasksReply flatbuffer object. */ - RedisModuleString *publish_topic = RedisString_Format( - ctx, "%s%S:%S", TASK_PREFIX, local_scheduler_id, state); - - /* Construct the flatbuffers object for the payload. */ - flatbuffers::FlatBufferBuilder fbb; - /* Use the old task spec if the current one is NULL. 
*/ - RedisModuleString *task_spec_to_use; - if (task_spec != NULL) { - task_spec_to_use = task_spec; - } else { - task_spec_to_use = existing_task_spec; - } - /* Create the flatbuffers message. */ - auto message = CreateTaskReply( - fbb, RedisStringToFlatbuf(fbb, task_id), state_value, - RedisStringToFlatbuf(fbb, local_scheduler_id), - RedisStringToFlatbuf(fbb, execution_dependencies), - RedisStringToFlatbuf(fbb, task_spec_to_use), spillback_count_value, - true); // The updated field is not used. - fbb.Finish(message); - - RedisModuleString *publish_message = RedisModule_CreateString( - ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize()); - - RedisModuleCallReply *reply = - RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message); - - /* See how many clients received this publish. */ - long long num_clients = RedisModule_CallReplyInteger(reply); - RAY_CHECK(num_clients <= 1) << "Published to " << num_clients - << " clients."; - - if (reply == NULL) { - return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); - } - - if (num_clients == 0) { - /* This reply will be received by redis_task_table_update_callback or - * redis_task_table_add_task_callback in redis.cc, which will then reissue - * the command. */ - return RedisModule_ReplyWithError(ctx, - "No subscribers received message."); - } - } - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - - return REDISMODULE_OK; -} - -/** - * Add a new entry to the task table. This will overwrite any existing entry - * with the same task ID. - * - * This is called from a client with the command: - * - * RAY.TASK_TABLE_ADD - * - * - * @param task_id A string that is the ID of the task. - * @param state A string that is the current scheduling state (a - * scheduling_state enum instance). - * @param local_scheduler_id A string that is the ray client ID of the - * associated local scheduler, if any. - * @param execution_dependencies A string that is the list of execution - * dependencies. - * @param task_spec A string that is the specification of the task, which can - * be cast to a `task_spec`. - * @return OK if the operation was successful. - */ -int TaskTableAddTask_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 7) { - return RedisModule_WrongArity(ctx); - } - - return TaskTableWrite(ctx, argv[1], argv[2], argv[3], argv[4], argv[5], - argv[6]); -} - -/** - * Update an entry in the task table. This does not update the task - * specification in the table. - * - * This is called from a client with the command: - * - * RAY.TASK_TABLE_UPDATE - * - * - * @param task_id A string that is the ID of the task. - * @param state A string that is the current scheduling state (a - * scheduling_state enum instance). - * @param ray_client_id A string that is the ray client ID of the associated - * local scheduler, if any. - * @param execution_dependencies A string that is the list of execution - * dependencies. - * @return OK if the operation was successful. - */ -int TaskTableUpdate_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 6) { - return RedisModule_WrongArity(ctx); - } - - return TaskTableWrite(ctx, argv[1], argv[2], argv[3], argv[4], argv[5], NULL); -} - -/** - * Test and update an entry in the task table if the current value matches the - * test value bitmask. This does not update the task specification in the - * table. 
- * - * This is called from a client with the command: - * - * RAY.TASK_TABLE_TEST_AND_UPDATE - * - * - * @param task_id A string that is the ID of the task. - * @param test_state_bitmask A string that is the test bitmask for the - * scheduling state. The update happens if and only if the current - * scheduling state AND-ed with the bitmask is greater than 0. - * @param state A string that is the scheduling state (a scheduling_state enum - * instance) to update the task entry with. - * @param ray_client_id A string that is the ray client ID of the associated - * local scheduler, if any, to update the task entry with. - * @param test_local_scheduler_id A string to test the local scheduler ID. If - * provided, and if the current local scheduler ID does not match it, - * then the update does not happen. - * @return Returns the task entry as a TaskReply. The reply will reflect the - * update, if it happened. - */ -int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc < 5 || argc > 6) { - return RedisModule_WrongArity(ctx); - } - /* If a sixth argument was provided, then we should also test the current - * local scheduler ID. */ - bool test_local_scheduler = (argc == 6); - - RedisModuleString *task_id = argv[1]; - RedisModuleString *test_state = argv[2]; - RedisModuleString *update_state = argv[3]; - RedisModuleString *local_scheduler_id = argv[4]; - - RedisModuleKey *key = OpenPrefixedKey(ctx, TASK_PREFIX, task_id, - REDISMODULE_READ | REDISMODULE_WRITE); - if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { - return RedisModule_ReplyWithNull(ctx); - } - - /* If the key exists, look up the fields and return them in an array. */ - RedisModuleString *current_state = NULL; - RedisModuleString *current_local_scheduler_id = NULL; - RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", ¤t_state, - "local_scheduler_id", ¤t_local_scheduler_id, NULL); - - long long current_state_integer; - if (RedisModule_StringToLongLong(current_state, ¤t_state_integer) != - REDISMODULE_OK) { - return RedisModule_ReplyWithError(ctx, "current_state must be integer"); - } - - if (current_state_integer < 0) { - return RedisModule_ReplyWithError(ctx, "Found invalid scheduling state."); - } - long long test_state_bitmask; - int status = RedisModule_StringToLongLong(test_state, &test_state_bitmask); - if (status != REDISMODULE_OK) { - return RedisModule_ReplyWithError( - ctx, "Invalid test value for scheduling state"); - } - - bool update = false; - if (current_state_integer & test_state_bitmask) { - if (test_local_scheduler) { - /* A test local scheduler ID was provided. Test whether it is equal to - * the current local scheduler ID before performing the update. */ - RedisModuleString *test_local_scheduler_id = argv[5]; - if (RedisModule_StringCompare(current_local_scheduler_id, - test_local_scheduler_id) == 0) { - /* If the current local scheduler ID does matches the test ID, then - * perform the update. */ - update = true; - } - } else { - /* No test local scheduler ID was provided. Perform the update. */ - update = true; - } - } - - /* If the scheduling state and local scheduler ID tests passed, then perform - * the update. */ - if (update) { - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", update_state, - "local_scheduler_id", local_scheduler_id, NULL); - } - - /* Construct a reply by getting the task from the task ID. 
*/ - return ReplyWithTask(ctx, task_id, update); -} - -/** - * Get an entry from the task table. - * - * This is called from a client with the command: - * - * RAY.TASK_TABLE_GET - * - * @param task_id A string of the task ID to look up. - * @return An array of strings representing the task fields in the following - * order: 1) (integer) scheduling state 2) (string) associated local - * scheduler ID, if any 3) (string) the task specification, which can be - * cast to a task_spec. If the task ID is not in the table, returns nil. - */ -int TaskTableGet_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 2) { - return RedisModule_WrongArity(ctx); - } - - /* Construct a reply by getting the task from the task ID. */ - return ReplyWithTask(ctx, argv[1], false); -} - -extern "C" { - -/* This function must be present on each Redis module. It is used in order to - * register the commands into the Redis server. */ -int RedisModule_OnLoad(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - REDISMODULE_NOT_USED(argv); - REDISMODULE_NOT_USED(argc); - - if (RedisModule_Init(ctx, "ray", 1, REDISMODULE_APIVER_1) == - REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.connect", Connect_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.disconnect", Disconnect_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_add", TableAdd_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_append", - TableAppend_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_lookup", - TableLookup_RedisCommand, "readonly", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_request_notifications", - TableRequestNotifications_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_cancel_notifications", - TableCancelNotifications_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update", - TableTestAndUpdate_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.object_table_lookup", - ObjectTableLookup_RedisCommand, "readonly", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.object_table_add", - ObjectTableAdd_RedisCommand, "write pubsub", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.object_table_remove", - ObjectTableRemove_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.object_table_request_notifications", - ObjectTableRequestNotifications_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.object_info_subscribe", - ObjectInfoSubscribe_RedisCommand, "pubsub", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, 
"ray.result_table_add", - ResultTableAdd_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.result_table_lookup", - ResultTableLookup_RedisCommand, "readonly", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.task_table_add", - TaskTableAddTask_RedisCommand, "write pubsub", - 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.task_table_update", - TaskTableUpdate_RedisCommand, "write pubsub", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.task_table_test_and_update", - TaskTableTestAndUpdate_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.task_table_get", - TaskTableGet_RedisCommand, "readonly", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - -#if RAY_USE_NEW_GCS - // Chain-enabled commands that depend on ray-project/credis. - if (RedisModule_CreateCommand(ctx, "ray.chain.table_add", - ChainTableAdd_RedisCommand, "write pubsub", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - if (RedisModule_CreateCommand(ctx, "ray.chain.table_append", - ChainTableAppend_RedisCommand, "write pubsub", - 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } -#endif - - return REDISMODULE_OK; -} - -} /* extern "C" */ diff --git a/src/common/shims/windows/getopt.c b/src/common/shims/windows/getopt.c deleted file mode 100644 index d9c4ae583307f..0000000000000 --- a/src/common/shims/windows/getopt.c +++ /dev/null @@ -1,69 +0,0 @@ -/* http://stackoverflow.com/a/17195644/541686 */ - -#include -#include - -int opterr = 1, /* if error message should be printed */ - optind = 1, /* index into parent argv vector */ - optopt, /* character checked for validity */ - optreset; /* reset getopt */ -char *optarg; /* argument associated with option */ - -#define BADCH (int) '?' -#define BADARG (int) ':' -#define EMSG "" - -/* -* getopt -- -* Parse argc/argv argument vector. -*/ -int getopt(int nargc, char *const nargv[], const char *ostr) { - static char *place = EMSG; /* option letter processing */ - const char *oli; /* option letter list index */ - - if (optreset || !*place) { /* update scanning pointer */ - optreset = 0; - if (optind >= nargc || *(place = nargv[optind]) != '-') { - place = EMSG; - return (-1); - } - if (place[1] && *++place == '-') { /* found "--" */ - ++optind; - place = EMSG; - return (-1); - } - } /* option letter okay? */ - if ((optopt = (int) *place++) == (int) ':' || !(oli = strchr(ostr, optopt))) { - /* - * if the user didn't specify '-' as an option, - * assume it means -1. 
- */ - if (optopt == (int) '-') - return (-1); - if (!*place) - ++optind; - if (opterr && *ostr != ':') - (void) printf("illegal option -- %c\n", optopt); - return (BADCH); - } - if (*++oli != ':') { /* don't need argument */ - optarg = NULL; - if (!*place) - ++optind; - } else { /* need an argument */ - if (*place) /* no white space */ - optarg = place; - else if (nargc <= ++optind) { /* no arg */ - place = EMSG; - if (*ostr == ':') - return (BADARG); - if (opterr) - (void) printf("option requires an argument -- %c\n", optopt); - return (BADCH); - } else /* white space */ - optarg = nargv[optind]; - place = EMSG; - ++optind; - } - return (optopt); /* dump back option letter */ -} diff --git a/src/common/shims/windows/getopt.h b/src/common/shims/windows/getopt.h deleted file mode 100644 index 1870fb87f7930..0000000000000 --- a/src/common/shims/windows/getopt.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef GETOPT_H -#define GETOPT_H - -#endif /* GETOPT_H */ diff --git a/src/common/shims/windows/msg.c b/src/common/shims/windows/msg.c deleted file mode 100644 index 5142c1aadf2ed..0000000000000 --- a/src/common/shims/windows/msg.c +++ /dev/null @@ -1,208 +0,0 @@ -#include - -int socketpair(int domain, int type, int protocol, int sv[2]) { - if ((domain != AF_UNIX && domain != AF_INET) || type != SOCK_STREAM) { - return INVALID_SOCKET; - } - SOCKET sockets[2]; - int r = dumb_socketpair(sockets); - sv[0] = (int) sockets[0]; - sv[1] = (int) sockets[1]; - return r; -} - -#pragma comment(lib, "IPHlpAPI.lib") - -struct _MIB_TCPROW2 { - DWORD dwState, dwLocalAddr, dwLocalPort, dwRemoteAddr, dwRemotePort, - dwOwningPid; - enum _TCP_CONNECTION_OFFLOAD_STATE dwOffloadState; -}; - -struct _MIB_TCPTABLE2 { - DWORD dwNumEntries; - struct _MIB_TCPROW2 table[1]; -}; - -DECLSPEC_IMPORT ULONG WINAPI GetTcpTable2(struct _MIB_TCPTABLE2 *TcpTable, - PULONG SizePointer, - BOOL Order); - -static DWORD getsockpid(SOCKET client) { - /* http://stackoverflow.com/a/25431340 */ - DWORD pid = 0; - - struct sockaddr_in Server = {0}; - int ServerSize = sizeof(Server); - - struct sockaddr_in Client = {0}; - int ClientSize = sizeof(Client); - - if ((getsockname(client, (struct sockaddr *) &Server, &ServerSize) == 0) && - (getpeername(client, (struct sockaddr *) &Client, &ClientSize) == 0)) { - struct _MIB_TCPTABLE2 *TcpTable = NULL; - ULONG TcpTableSize = 0; - ULONG result; - do { - result = GetTcpTable2(TcpTable, &TcpTableSize, TRUE); - if (result != ERROR_INSUFFICIENT_BUFFER) { - break; - } - free(TcpTable); - TcpTable = (struct _MIB_TCPTABLE2 *) malloc(TcpTableSize); - } while (TcpTable != NULL); - - if (result == NO_ERROR) { - for (DWORD dw = 0; dw < TcpTable->dwNumEntries; ++dw) { - struct _MIB_TCPROW2 *row = &(TcpTable->table[dw]); - if ((row->dwState == 5 /* MIB_TCP_STATE_ESTAB */) && - (row->dwLocalAddr == Client.sin_addr.s_addr) && - ((row->dwLocalPort & 0xFFFF) == Client.sin_port) && - (row->dwRemoteAddr == Server.sin_addr.s_addr) && - ((row->dwRemotePort & 0xFFFF) == Server.sin_port)) { - pid = row->dwOwningPid; - break; - } - } - } - - free(TcpTable); - } - - return pid; -} - -ssize_t sendmsg(int sockfd, struct msghdr *msg, int flags) { - ssize_t result = -1; - struct cmsghdr *header = CMSG_FIRSTHDR(msg); - if (header->cmsg_level == SOL_SOCKET && header->cmsg_type == SCM_RIGHTS) { - /* We're trying to send over a handle of some kind. - * We have to look up which process we're communicating with, - * open a handle to it, and then duplicate our handle into it. - * However, the first two steps cannot be done atomically. 
- * Therefore, this code HAS A RACE CONDITIONS and is therefore NOT SECURE. - * In the absense of a malicious actor, though, it is exceedingly unlikely - * that the child process closes AND that its process ID is reassigned - * to another existing process. - */ - struct msghdr const old_msg = *msg; - int *const pfd = (int *) CMSG_DATA(header); - msg->msg_control = NULL; - msg->msg_controllen = 0; - WSAPROTOCOL_INFO protocol_info = {0}; - BOOL const is_socket = !!FDAPI_GetSocketStatePtr(*pfd); - DWORD const target_pid = getsockpid(sockfd); - HANDLE target_process = NULL; - if (target_pid) { - if (!is_socket) { - /* This is a regular handle... fit it into the same struct */ - target_process = OpenProcess(PROCESS_DUP_HANDLE, FALSE, target_pid); - if (target_process) { - if (DuplicateHandle(GetCurrentProcess(), (HANDLE)(intptr_t) *pfd, - target_process, (HANDLE *) &protocol_info, 0, - TRUE, DUPLICATE_SAME_ACCESS)) { - result = 0; - } - } - } else { - /* This is a socket... */ - result = FDAPI_WSADuplicateSocket(*pfd, target_pid, &protocol_info); - } - } - if (result == 0) { - int const nbufs = msg->dwBufferCount + 1; - WSABUF *const bufs = - (struct _WSABUF *) _alloca(sizeof(*msg->lpBuffers) * nbufs); - bufs[0].buf = (char *) &protocol_info; - bufs[0].len = sizeof(protocol_info); - memcpy(&bufs[1], msg->lpBuffers, - msg->dwBufferCount * sizeof(*msg->lpBuffers)); - DWORD nb; - msg->lpBuffers = bufs; - msg->dwBufferCount = nbufs; - GUID const wsaid_WSASendMsg = { - 0xa441e712, - 0x754f, - 0x43ca, - {0x84, 0xa7, 0x0d, 0xee, 0x44, 0xcf, 0x60, 0x6d}}; - typedef INT PASCAL WSASendMsg_t( - SOCKET s, LPWSAMSG lpMsg, DWORD dwFlags, LPDWORD lpNumberOfBytesSent, - LPWSAOVERLAPPED lpOverlapped, - LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine); - WSASendMsg_t *WSASendMsg = NULL; - result = FDAPI_WSAIoctl(sockfd, SIO_GET_EXTENSION_FUNCTION_POINTER, - &wsaid_WSASendMsg, sizeof(wsaid_WSASendMsg), - &WSASendMsg, sizeof(WSASendMsg), &nb, NULL, 0); - if (result == 0) { - result = (*WSASendMsg)(sockfd, msg, flags, &nb, NULL, NULL) == 0 - ? (ssize_t)(nb - sizeof(protocol_info)) - : 0; - } - } - if (result != 0 && target_process && !is_socket) { - /* we failed to send the handle, and it needs cleaning up! */ - HANDLE duplicated_back = NULL; - if (DuplicateHandle(target_process, *(HANDLE *) &protocol_info, - GetCurrentProcess(), &duplicated_back, 0, FALSE, - DUPLICATE_CLOSE_SOURCE)) { - CloseHandle(duplicated_back); - } - } - if (target_process) { - CloseHandle(target_process); - } - *msg = old_msg; - } - return result; -} - -ssize_t recvmsg(int sockfd, struct msghdr *msg, int flags) { - int result = -1; - struct cmsghdr *header = CMSG_FIRSTHDR(msg); - if (msg->msg_controllen && - flags == 0 /* We can't send flags on Windows... 
*/) { - struct msghdr const old_msg = *msg; - msg->msg_control = NULL; - msg->msg_controllen = 0; - WSAPROTOCOL_INFO protocol_info = {0}; - int const nbufs = msg->dwBufferCount + 1; - WSABUF *const bufs = - (struct _WSABUF *) _alloca(sizeof(*msg->lpBuffers) * nbufs); - bufs[0].buf = (char *) &protocol_info; - bufs[0].len = sizeof(protocol_info); - memcpy(&bufs[1], msg->lpBuffers, - msg->dwBufferCount * sizeof(*msg->lpBuffers)); - typedef INT PASCAL WSARecvMsg_t( - SOCKET s, LPWSAMSG lpMsg, LPDWORD lpNumberOfBytesRecvd, - LPWSAOVERLAPPED lpOverlapped, - LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine); - WSARecvMsg_t *WSARecvMsg = NULL; - DWORD nb; - GUID const wsaid_WSARecvMsg = { - 0xf689d7c8, - 0x6f1f, - 0x436b, - {0x8a, 0x53, 0xe5, 0x4f, 0xe3, 0x51, 0xc3, 0x22}}; - result = FDAPI_WSAIoctl(sockfd, SIO_GET_EXTENSION_FUNCTION_POINTER, - &wsaid_WSARecvMsg, sizeof(wsaid_WSARecvMsg), - &WSARecvMsg, sizeof(WSARecvMsg), &nb, NULL, 0); - if (result == 0) { - result = (*WSARecvMsg)(sockfd, msg, &nb, NULL, NULL) == 0 - ? (ssize_t)(nb - sizeof(protocol_info)) - : 0; - } - if (result == 0) { - int *const pfd = (int *) CMSG_DATA(header); - if (protocol_info.iSocketType == 0 && protocol_info.iProtocol == 0) { - *pfd = *(int *) &protocol_info; - } else { - *pfd = FDAPI_WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, - FROM_PROTOCOL_INFO, &protocol_info, 0, 0); - } - header->cmsg_level = SOL_SOCKET; - header->cmsg_type = SCM_RIGHTS; - } - *msg = old_msg; - } - return result; -} diff --git a/src/common/shims/windows/netdb.h b/src/common/shims/windows/netdb.h deleted file mode 100644 index 5dace165919a2..0000000000000 --- a/src/common/shims/windows/netdb.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef NETDB_H -#define NETDB_H - -#endif /* NETDB_H */ diff --git a/src/common/shims/windows/netinet/in.h b/src/common/shims/windows/netinet/in.h deleted file mode 100644 index a60db3e05dd62..0000000000000 --- a/src/common/shims/windows/netinet/in.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef IN_H -#define IN_H - -#endif /* IN_H */ diff --git a/src/common/shims/windows/poll.h b/src/common/shims/windows/poll.h deleted file mode 100644 index 058e23adee645..0000000000000 --- a/src/common/shims/windows/poll.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef POLL_H -#define POLL_H - -#endif /* POLL_H */ diff --git a/src/common/shims/windows/socketpair.c b/src/common/shims/windows/socketpair.c deleted file mode 100644 index e9fc792c15a70..0000000000000 --- a/src/common/shims/windows/socketpair.c +++ /dev/null @@ -1,150 +0,0 @@ -/* socketpair.c -Copyright 2007, 2010 by Nathan C. Myers -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - The name of the author must not be used to endorse or promote products - derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. 
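/* Illustrative sketch, not from the original sources: the sendmsg()/recvmsg()
 * shims deleted above emulate POSIX descriptor passing on Windows by
 * duplicating handles into the peer process. For comparison, the Unix
 * mechanism they stand in for sends the descriptor as SCM_RIGHTS ancillary
 * data over a connected AF_UNIX stream socket, roughly like this: */
#include <string.h>
#include <sys/socket.h>
#include <sys/uio.h>

/* Send one file descriptor over a connected AF_UNIX stream socket. */
static int send_fd(int sock, int fd_to_send) {
  char payload = 'x'; /* At least one byte of ordinary data must accompany
                       * the ancillary data. */
  struct iovec iov;
  iov.iov_base = &payload;
  iov.iov_len = 1;

  char control[CMSG_SPACE(sizeof(int))];
  memset(control, 0, sizeof(control));

  struct msghdr msg;
  memset(&msg, 0, sizeof(msg));
  msg.msg_iov = &iov;
  msg.msg_iovlen = 1;
  msg.msg_control = control;
  msg.msg_controllen = sizeof(control);

  struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
  cmsg->cmsg_level = SOL_SOCKET;
  cmsg->cmsg_type = SCM_RIGHTS;
  cmsg->cmsg_len = CMSG_LEN(sizeof(int));
  memcpy(CMSG_DATA(cmsg), &fd_to_send, sizeof(int));

  return sendmsg(sock, &msg, 0) == 1 ? 0 : -1;
}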
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -/* Changes: - * 2014-02-12: merge David Woodhouse, Ger Hobbelt improvements - * git.infradead.org/users/dwmw2/openconnect.git/commitdiff/bdeefa54 - * github.com/GerHobbelt/selectable-socketpair - * always init the socks[] to -1/INVALID_SOCKET on error, both on Win32/64 - * and UNIX/other platforms - * 2013-07-18: Change to BSD 3-clause license - * 2010-03-31: - * set addr to 127.0.0.1 because win32 getsockname does not always set it. - * 2010-02-25: - * set SO_REUSEADDR option to avoid leaking some windows resource. - * Windows System Error 10049, "Event ID 4226 TCP/IP has reached - * the security limit imposed on the number of concurrent TCP connect - * attempts." Bleah. - * 2007-04-25: - * preserve value of WSAGetLastError() on all error returns. - * 2007-04-22: (Thanks to Matthew Gregan ) - * s/EINVAL/WSAEINVAL/ fix trivial compile failure - * s/socket/WSASocket/ enable creation of sockets suitable as stdin/stdout - * of a child process. - * add argument make_overlapped - */ - -#include - -#ifdef WIN32 -#include /* socklen_t, et al (MSVC20xx) */ -#include -#include -#else -#ifdef _WIN32 -#include -#include -#endif -#include -#include -#include -#endif - -#ifdef WIN32 - -/* dumb_socketpair: - * If make_overlapped is nonzero, both sockets created will be usable for - * "overlapped" operations via WSASend etc. If make_overlapped is zero, - * socks[0] (only) will be usable with regular ReadFile etc., and thus - * suitable for use as stdin or stdout of a child process. Note that the - * sockets must be closed with closesocket() regardless. - */ - -int dumb_socketpair(SOCKET socks[2]) { - union { - struct sockaddr_in inaddr; - struct sockaddr addr; - } a; - SOCKET listener; - int e; - socklen_t addrlen = sizeof(a.inaddr); - int reuse = 1; - - if (socks == 0) { - return SOCKET_ERROR; - } - socks[0] = socks[1] = -1; - - listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if (listener == -1) - return SOCKET_ERROR; - - memset(&a, 0, sizeof(a)); - a.inaddr.sin_family = AF_INET; - a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - a.inaddr.sin_port = 0; - - for (;;) { - if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, (char *) &reuse, - (socklen_t) sizeof(reuse)) == -1) - break; - if (bind(listener, &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR) - break; - - memset(&a, 0, sizeof(a)); - if (getsockname(listener, &a.addr, &addrlen) == SOCKET_ERROR) - break; - // win32 getsockname may only set the port number, p=0.0005. 
- // ( http://msdn.microsoft.com/library/ms738543.aspx ): - a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - a.inaddr.sin_family = AF_INET; - - if (listen(listener, 1) == SOCKET_ERROR) - break; - - socks[0] = FDAPI_WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, 0); - if (socks[0] == -1) - break; - if (connect(socks[0], &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR) - break; - - socks[1] = accept(listener, NULL, NULL); - if (socks[1] == -1) - break; - - FDAPI_close(listener); - return 0; - } - - FDAPI_close(listener); - FDAPI_close(socks[0]); - FDAPI_close(socks[1]); - socks[0] = socks[1] = -1; - return SOCKET_ERROR; -} -#else -int dumb_socketpair(int socks[2], int dummy) { - if (socks == 0) { - errno = EINVAL; - return -1; - } - dummy = socketpair(AF_LOCAL, SOCK_STREAM, 0, socks); - if (dummy) - socks[0] = socks[1] = -1; - return dummy; -} -#endif diff --git a/src/common/shims/windows/strings.h b/src/common/shims/windows/strings.h deleted file mode 100644 index e264061c4e6ef..0000000000000 --- a/src/common/shims/windows/strings.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef STRINGS_H -#define STRINGS_H - -#endif /* STRINGS_H */ diff --git a/src/common/shims/windows/sys/ioctl.h b/src/common/shims/windows/sys/ioctl.h deleted file mode 100644 index 00f7a55ed77dc..0000000000000 --- a/src/common/shims/windows/sys/ioctl.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef IOCTL_H -#define IOCTL_H - -#endif /* IOCTL_H */ diff --git a/src/common/shims/windows/sys/mman.h b/src/common/shims/windows/sys/mman.h deleted file mode 100644 index a12df75fc7eac..0000000000000 --- a/src/common/shims/windows/sys/mman.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef MMAN_H -#define MMAN_H - -#include - -#define MAP_SHARED 0x0010 /* share changes */ -#define MAP_FAILED ((void *) -1) -#define PROT_READ 0x04 /* pages can be read */ -#define PROT_WRITE 0x02 /* pages can be written */ -#define PROT_EXEC 0x01 /* pages can be executed */ - -static void *mmap(void *addr, - size_t len, - int prot, - int flags, - int fd, - off_t off) { - void *result = (void *) (-1); - if (!addr && prot == MAP_SHARED) { - /* HACK: we're assuming handle sizes can't exceed 32 bits, which is wrong... - * but works for now. */ - void *ptr = MapViewOfFile((HANDLE)(intptr_t) fd, FILE_MAP_ALL_ACCESS, - (DWORD)(off >> (CHAR_BIT * sizeof(DWORD))), - (DWORD) off, (SIZE_T) len); - if (ptr) { - result = ptr; - } - } - return result; -} -static int munmap(void *addr, size_t length) { - (void) length; - return UnmapViewOfFile(addr) ? 
0 : -1; -} - -#endif /* MMAN_H */ diff --git a/src/common/shims/windows/sys/select.h b/src/common/shims/windows/sys/select.h deleted file mode 100644 index 8aef7950e3993..0000000000000 --- a/src/common/shims/windows/sys/select.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef SELECT_H -#define SELECT_H - -#endif /* SELECT_H */ diff --git a/src/common/shims/windows/sys/socket.h b/src/common/shims/windows/sys/socket.h deleted file mode 100644 index ba9d656bb96d9..0000000000000 --- a/src/common/shims/windows/sys/socket.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef SOCKET_H -#define SOCKET_H - -typedef unsigned short sa_family_t; - -#include "../../src/Win32_Interop/Win32_FDAPI.h" -#include "../../src/Win32_Interop/Win32_APIs.h" - -#define cmsghdr _WSACMSGHDR -#undef CMSG_DATA -#define CMSG_DATA WSA_CMSG_DATA -#define CMSG_SPACE WSA_CMSG_SPACE -#define CMSG_FIRSTHDR WSA_CMSG_FIRSTHDR -#define CMSG_LEN WSA_CMSG_LEN -#define CMSG_NXTHDR WSA_CMSG_NXTHDR - -#define SCM_RIGHTS 1 - -#define iovec _WSABUF -#define iov_base buf -#define iov_len len -#define msghdr _WSAMSG -#define msg_name name -#define msg_namelen namelen -#define msg_iov lpBuffers -#define msg_iovlen dwBufferCount -#define msg_control Control.buf -#define msg_controllen Control.len -#define msg_flags dwFlags - -int dumb_socketpair(SOCKET socks[2]); -ssize_t sendmsg(int sockfd, struct msghdr *msg, int flags); -ssize_t recvmsg(int sockfd, struct msghdr *msg, int flags); -int socketpair(int domain, int type, int protocol, int sv[2]); - -#endif /* SOCKET_H */ diff --git a/src/common/shims/windows/sys/time.h b/src/common/shims/windows/sys/time.h deleted file mode 100644 index 976342bd21215..0000000000000 --- a/src/common/shims/windows/sys/time.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef TIME_H -#define TIME_H - -#include /* timeval */ - -int gettimeofday_highres(struct timeval *tv, struct timezone *tz); - -static int gettimeofday(struct timeval *tv, struct timezone *tz) { - return gettimeofday_highres(tv, tz); -} - -#endif /* TIME_H */ diff --git a/src/common/shims/windows/sys/un.h b/src/common/shims/windows/sys/un.h deleted file mode 100644 index 91642683f72eb..0000000000000 --- a/src/common/shims/windows/sys/un.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef UN_H -#define UN_H - -#include - -struct sockaddr_un { - /** AF_UNIX. */ - sa_family_t sun_family; - /** The pathname. */ - char sun_path[108]; -}; - -#endif /* UN_H */ diff --git a/src/common/shims/windows/sys/wait.h b/src/common/shims/windows/sys/wait.h deleted file mode 100644 index 442218408f976..0000000000000 --- a/src/common/shims/windows/sys/wait.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef WAIT_H -#define WAIT_H - -#endif /* WAIT_H */ diff --git a/src/common/shims/windows/unistd.h b/src/common/shims/windows/unistd.h deleted file mode 100644 index aab25417e199b..0000000000000 --- a/src/common/shims/windows/unistd.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef UNISTD_H -#define UNISTD_H - -extern char *optarg; -extern int optind, opterr, optopt; -int getopt(int nargc, char *const nargv[], const char *ostr); - -#include "../../src/Win32_Interop/Win32_FDAPI.h" -#define close(...) 
FDAPI_close(__VA_ARGS__) - -#endif /* UNISTD_H */ diff --git a/src/common/state/actor_notification_table.cc b/src/common/state/actor_notification_table.cc deleted file mode 100644 index 19cd7fddda41e..0000000000000 --- a/src/common/state/actor_notification_table.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "actor_notification_table.h" - -#include "common_protocol.h" -#include "redis.h" - -void publish_actor_creation_notification(DBHandle *db_handle, - const ActorID &actor_id, - const WorkerID &driver_id, - const DBClientID &local_scheduler_id) { - // Create a flatbuffer object to serialize and publish. - flatbuffers::FlatBufferBuilder fbb; - // Create the flatbuffers message. - auto message = CreateActorCreationNotification( - fbb, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, driver_id), - to_flatbuf(fbb, local_scheduler_id)); - fbb.Finish(message); - - ActorCreationNotificationData *data = - (ActorCreationNotificationData *) malloc( - sizeof(ActorCreationNotificationData) + fbb.GetSize()); - data->size = fbb.GetSize(); - memcpy(&data->flatbuffer_data[0], fbb.GetBufferPointer(), fbb.GetSize()); - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(data), NULL, NULL, - redis_publish_actor_creation_notification, NULL); -} - -void actor_notification_table_subscribe( - DBHandle *db_handle, - actor_notification_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry) { - ActorNotificationTableSubscribeData *sub_data = - (ActorNotificationTableSubscribeData *) malloc( - sizeof(ActorNotificationTableSubscribeData)); - sub_data->subscribe_callback = subscribe_callback; - sub_data->subscribe_context = subscribe_context; - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(sub_data), retry, NULL, - redis_actor_notification_table_subscribe, NULL); -} - -void actor_table_mark_removed(DBHandle *db_handle, ActorID actor_id) { - redis_actor_table_mark_removed(db_handle, actor_id); -} diff --git a/src/common/state/actor_notification_table.h b/src/common/state/actor_notification_table.h deleted file mode 100644 index f6aa101cd0d01..0000000000000 --- a/src/common/state/actor_notification_table.h +++ /dev/null @@ -1,74 +0,0 @@ -#ifndef ACTOR_NOTIFICATION_TABLE_H -#define ACTOR_NOTIFICATION_TABLE_H - -#include "task.h" -#include "db.h" -#include "table.h" - -/* - * ==== Subscribing to the actor notification table ==== - */ - -/* Callback for subscribing to the local scheduler table. */ -typedef void (*actor_notification_table_subscribe_callback)( - const ActorID &actor_id, - const WorkerID &driver_id, - const DBClientID &local_scheduler_id, - void *user_context); - -/// Publish an actor creation notification. This is published by a local -/// scheduler once it creates an actor. -/// -/// \param db_handle Database handle. -/// \param actor_id The ID of the actor that was created. -/// \param driver_id The ID of the driver that created the actor. -/// \param local_scheduler_id The ID of the local scheduler that created the -/// actor. -/// \return Void. -void publish_actor_creation_notification(DBHandle *db_handle, - const ActorID &actor_id, - const WorkerID &driver_id, - const DBClientID &local_scheduler_id); - -/// Data that is needed to publish an actor creation notification. -typedef struct { - /// The size of the flatbuffer object. - int64_t size; - /// The information to be sent. 
- uint8_t flatbuffer_data[0]; -} ActorCreationNotificationData; - -/** - * Register a callback to process actor notification events. - * - * @param db_handle Database handle. - * @param subscribe_callback Callback that will be called when the local - * scheduler event happens. - * @param subscribe_context Context that will be passed into the - * subscribe_callback. - * @param retry Information about retrying the request to the database. - * @return Void. - */ -void actor_notification_table_subscribe( - DBHandle *db_handle, - actor_notification_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry); - -/* Data that is needed to register local scheduler table subscribe callbacks - * with the state database. */ -typedef struct { - actor_notification_table_subscribe_callback subscribe_callback; - void *subscribe_context; -} ActorNotificationTableSubscribeData; - -/** - * Marks an actor as removed. This prevents the actor from being resurrected. - * - * @param db The database handle. - * @param actor_id The actor id to mark as removed. - * @return Void. - */ -void actor_table_mark_removed(DBHandle *db_handle, ActorID actor_id); - -#endif /* ACTOR_NOTIFICATION_TABLE_H */ diff --git a/src/common/state/db.h b/src/common/state/db.h deleted file mode 100644 index ac9960b89374b..0000000000000 --- a/src/common/state/db.h +++ /dev/null @@ -1,70 +0,0 @@ -#ifndef DB_H -#define DB_H - -#include - -#include "common.h" -#include "event_loop.h" - -typedef struct DBHandle DBHandle; - -/** - * Connect to the global system store. - * - * @param db_address The hostname to use to connect to the database. - * @param db_port The port to use to connect to the database. - * @param db_shards_addresses The list of database shard IP addresses. - * @param db_shards_ports The list of database shard ports, in the same order - * as db_shards_addresses. - * @param client_type The type of this client. - * @param node_ip_address The hostname of the client that is connecting. - * @param args A vector of extra arguments strings. They should alternate - * between the name of the argument and the value of the argument. For - * examples: "port", "1234", "socket_name", "/tmp/s1". This vector should - * have an even length. - * @return This returns a handle to the database, which must be freed with - * db_disconnect after use. - */ -DBHandle *db_connect(const std::string &db_primary_address, - int db_primary_port, - const char *client_type, - const char *node_ip_address, - const std::vector &args); - -/** - * Attach global system store connection to an event loop. Callbacks from - * queries to the global system store will trigger events in the event loop. - * - * @param db The handle to the database that is connected. - * @param loop The event loop the database gets connected to. - * @param reattach Can only be true in unit tests. If true, the database is - * reattached to the loop. - * @return Void. - */ -void db_attach(DBHandle *db, event_loop *loop, bool reattach); - -/** - * Disconnect from the global system store. - * - * @param db The database connection to close and clean up. - * @return Void. - */ -void db_disconnect(DBHandle *db); - -/** - * Free the database handle. - * - * @param db The database connection to clean up. - * @return Void. - */ -void DBHandle_free(DBHandle *db); - -/** - * Returns the db client ID. - * - * @param db The handle to the database. - * @returns int The db client ID for this connection to the database. 
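/* Illustrative sketch, not from the original sources, and only meaningful
 * against the legacy db.h being deleted here: the handle declared above was
 * used in a connect / attach / disconnect sequence by the old C components.
 * The client type string and the name/value argument pairs are placeholders
 * taken from the doc comment above, not a required configuration. */
#include <string>
#include <vector>

void legacy_db_lifecycle_example(event_loop *loop) {
  std::vector<std::string> args = {"port", "1234", "socket_name", "/tmp/s1"};
  DBHandle *db =
      db_connect("127.0.0.1", 6379, "test_client", "127.0.0.1", args);
  /* Route callbacks from database queries through the component's event loop. */
  db_attach(db, loop, false);
  /* ... run the event loop and issue table operations here ... */
  db_disconnect(db);
}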
- */ -DBClientID get_db_client_id(DBHandle *db); - -#endif diff --git a/src/common/state/db_client_table.cc b/src/common/state/db_client_table.cc deleted file mode 100644 index b31e9d8c2d3a6..0000000000000 --- a/src/common/state/db_client_table.cc +++ /dev/null @@ -1,90 +0,0 @@ -#include "db_client_table.h" -#include "redis.h" - -void db_client_table_remove(DBHandle *db_handle, - DBClientID db_client_id, - RetryInfo *retry, - db_client_table_done_callback done_callback, - void *user_context) { - init_table_callback(db_handle, db_client_id, __func__, - new CommonCallbackData(NULL), retry, - (table_done_callback) done_callback, - redis_db_client_table_remove, user_context); -} - -void db_client_table_subscribe( - DBHandle *db_handle, - db_client_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry, - db_client_table_done_callback done_callback, - void *user_context) { - DBClientTableSubscribeData *sub_data = - (DBClientTableSubscribeData *) malloc(sizeof(DBClientTableSubscribeData)); - sub_data->subscribe_callback = subscribe_callback; - sub_data->subscribe_context = subscribe_context; - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(sub_data), retry, - (table_done_callback) done_callback, - redis_db_client_table_subscribe, user_context); -} - -const std::vector db_client_table_get_ip_addresses( - DBHandle *db_handle, - const std::vector &manager_ids) { - /* We time this function because in the past this loop has taken multiple - * seconds under stressful situations on hundreds of machines causing the - * plasma manager to die (because it went too long without sending - * heartbeats). */ - int64_t start_time = current_time_ms(); - - /* Construct the manager vector from the flatbuffers object. 
*/ - std::vector manager_vector; - - for (auto const &manager_id : manager_ids) { - DBClient client = redis_cache_get_db_client(db_handle, manager_id); - RAY_CHECK(!client.manager_address.empty()); - if (client.is_alive) { - manager_vector.push_back(client.manager_address); - } - } - - int64_t end_time = current_time_ms(); - if (end_time - start_time > RayConfig::instance().max_time_for_loop()) { - RAY_LOG(WARNING) << "calling redis_get_cached_db_client in a loop in with " - << manager_ids.size() << " manager IDs took " - << end_time - start_time << " milliseconds."; - } - - return manager_vector; -} - -void db_client_table_update_cache_callback(DBClient *db_client, - void *user_context) { - DBHandle *db_handle = (DBHandle *) user_context; - redis_cache_set_db_client(db_handle, *db_client); -} - -void db_client_table_cache_init(DBHandle *db_handle) { - db_client_table_subscribe(db_handle, db_client_table_update_cache_callback, - db_handle, NULL, NULL, NULL); -} - -DBClient db_client_table_cache_get(DBHandle *db_handle, DBClientID client_id) { - RAY_CHECK(!client_id.is_nil()); - return redis_cache_get_db_client(db_handle, client_id); -} - -void plasma_manager_send_heartbeat(DBHandle *db_handle) { - RetryInfo heartbeat_retry; - heartbeat_retry.num_retries = 0; - heartbeat_retry.timeout = - RayConfig::instance().heartbeat_timeout_milliseconds(); - heartbeat_retry.fail_callback = NULL; - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(NULL), - (RetryInfo *) &heartbeat_retry, NULL, - redis_plasma_manager_send_heartbeat, NULL); -} diff --git a/src/common/state/db_client_table.h b/src/common/state/db_client_table.h deleted file mode 100644 index d140ba770eee2..0000000000000 --- a/src/common/state/db_client_table.h +++ /dev/null @@ -1,120 +0,0 @@ -#ifndef DB_CLIENT_TABLE_H -#define DB_CLIENT_TABLE_H - -#include - -#include "db.h" -#include "table.h" - -typedef void (*db_client_table_done_callback)(DBClientID db_client_id, - void *user_context); - -/** - * Remove a client from the db clients table. - * - * @param db_handle Database handle. - * @param db_client_id The database client ID to remove. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. - * - */ -void db_client_table_remove(DBHandle *db_handle, - DBClientID db_client_id, - RetryInfo *retry, - db_client_table_done_callback done_callback, - void *user_context); - -/* - * ==== Subscribing to the db client table ==== - */ - -/* An entry in the db client table. */ -typedef struct { - /** The database client ID. */ - DBClientID id; - /** The database client type. */ - std::string client_type; - /** An optional auxiliary address for the plasma manager associated with this - * database client. */ - std::string manager_address; - /** Whether or not the database client exists. If this is false for an entry, - * then it will never again be true. */ - bool is_alive; -} DBClient; - -/* Callback for subscribing to the db client table. */ -typedef void (*db_client_table_subscribe_callback)(DBClient *db_client, - void *user_context); - -/** - * Register a callback for a db client table event. - * - * @param db_handle Database handle. - * @param subscribe_callback Callback that will be called when the db client - * table is updated. 
- * @param subscribe_context Context that will be passed into the - * subscribe_callback. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. - */ -void db_client_table_subscribe( - DBHandle *db_handle, - db_client_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry, - db_client_table_done_callback done_callback, - void *user_context); - -/* Data that is needed to register db client table subscribe callbacks with the - * state database. */ -typedef struct { - db_client_table_subscribe_callback subscribe_callback; - void *subscribe_context; -} DBClientTableSubscribeData; - -const std::vector db_client_table_get_ip_addresses( - DBHandle *db, - const std::vector &manager_ids); - -/** - * Initialize the db client cache. The cache is updated with each notification - * from the db client table. - * - * @param db_handle Database handle. - * @return Void. - */ -void db_client_table_cache_init(DBHandle *db_handle); - -/** - * Get a db client from the cache. If the requested client is not there, - * request the latest entry from the db client table. - * - * @param db_handle Database handle. - * @param client_id The ID of the client to look up in the cache. - * @return The database client in the cache. - */ -DBClient db_client_table_cache_get(DBHandle *db_handle, DBClientID client_id); - -/* - * ==== Plasma manager heartbeats ==== - */ - -/** - * Start sending heartbeats to the plasma_managers channel. Each - * heartbeat contains this database client's ID. Heartbeats can be subscribed - * to through the plasma_managers channel. Once called, this "retries" the - * heartbeat operation forever, every heartbeat_timeout_milliseconds - * milliseconds. - * - * @param db_handle Database handle. - * @return Void. - */ -void plasma_manager_send_heartbeat(DBHandle *db_handle); - -#endif /* DB_CLIENT_TABLE_H */ diff --git a/src/common/state/driver_table.cc b/src/common/state/driver_table.cc deleted file mode 100644 index b8732e9863b20..0000000000000 --- a/src/common/state/driver_table.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "driver_table.h" -#include "redis.h" - -void driver_table_subscribe(DBHandle *db_handle, - driver_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry) { - DriverTableSubscribeData *sub_data = - (DriverTableSubscribeData *) malloc(sizeof(DriverTableSubscribeData)); - sub_data->subscribe_callback = subscribe_callback; - sub_data->subscribe_context = subscribe_context; - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(sub_data), retry, NULL, - redis_driver_table_subscribe, NULL); -} - -void driver_table_send_driver_death(DBHandle *db_handle, - WorkerID driver_id, - RetryInfo *retry) { - init_table_callback(db_handle, driver_id, __func__, - new CommonCallbackData(NULL), retry, NULL, - redis_driver_table_send_driver_death, NULL); -} diff --git a/src/common/state/driver_table.h b/src/common/state/driver_table.h deleted file mode 100644 index c8c6a6c32382c..0000000000000 --- a/src/common/state/driver_table.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef DRIVER_TABLE_H -#define DRIVER_TABLE_H - -#include "db.h" -#include "table.h" -#include "task.h" - -/* - * ==== Subscribing to the driver table ==== - */ - -/* Callback for subscribing to the driver table. 
*/ -typedef void (*driver_table_subscribe_callback)(WorkerID driver_id, - void *user_context); - -/** - * Register a callback for a driver table event. - * - * @param db_handle Database handle. - * @param subscribe_callback Callback that will be called when the driver event - * happens. - * @param subscribe_context Context that will be passed into the - * subscribe_callback. - * @param retry Information about retrying the request to the database. - * @return Void. - */ -void driver_table_subscribe(DBHandle *db_handle, - driver_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry); - -/* Data that is needed to register driver table subscribe callbacks with the - * state database. */ -typedef struct { - driver_table_subscribe_callback subscribe_callback; - void *subscribe_context; -} DriverTableSubscribeData; - -/** - * Send driver death update to all subscribers. - * - * @param db_handle Database handle. - * @param driver_id The ID of the driver that died. - * @param retry Information about retrying the request to the database. - */ -void driver_table_send_driver_death(DBHandle *db_handle, - WorkerID driver_id, - RetryInfo *retry); - -#endif /* DRIVER_TABLE_H */ diff --git a/src/common/state/error_table.cc b/src/common/state/error_table.cc deleted file mode 100644 index d0fd9bdff5e9d..0000000000000 --- a/src/common/state/error_table.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "error_table.h" -#include "redis.h" - -const char *error_types[] = {"object_hash_mismatch", "put_reconstruction", - "worker_died", "actor_not_created"}; - -void push_error(DBHandle *db_handle, - DBClientID driver_id, - ErrorIndex error_type, - const std::string &error_message) { - int64_t message_size = error_message.size(); - - /* Allocate a struct to hold the error information. */ - ErrorInfo *info = (ErrorInfo *) malloc(sizeof(ErrorInfo) + message_size); - info->driver_id = driver_id; - info->error_type = error_type; - info->error_key = UniqueID::from_random(); - info->size = message_size; - memcpy(info->error_message, error_message.data(), message_size); - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(info), NULL, NULL, - redis_push_error, NULL); -} diff --git a/src/common/state/error_table.h b/src/common/state/error_table.h deleted file mode 100644 index 908d7f4d0eaad..0000000000000 --- a/src/common/state/error_table.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef ERROR_TABLE_H -#define ERROR_TABLE_H - -#include "db.h" -#include "table.h" - -/// An ErrorIndex may be used as an index into error_types. -enum class ErrorIndex : int32_t { - /// An object was added with a different hash from the existing one. - OBJECT_HASH_MISMATCH = 0, - /// An object that was created through a ray.put is lost. - PUT_RECONSTRUCTION, - /// A worker died or was killed while executing a task. - WORKER_DIED, - /// An actor hasn't been created for a while. - ACTOR_NOT_CREATED, - /// The total number of error types. - MAX -}; - -/// Data that is needed to push an error. -typedef struct { - /// The ID of the driver to push the error to. - DBClientID driver_id; - /// An index into the error_types array indicating the type of the error. - ErrorIndex error_type; - /// The key to use for the error message in Redis. - UniqueID error_key; - /// The length of the error message. - int64_t size; - /// The error message. - uint8_t error_message[0]; -} ErrorInfo; - -extern const char *error_types[]; - -/// Push an error to the given Python driver. 
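/* Illustrative sketch, not from the original sources; it depends on the
 * error_table.h removed in this change. Reporting a worker failure to a
 * driver through the legacy error table used the push_error() call declared
 * below, for example: */
void report_worker_death_example(DBHandle *db, DBClientID driver_id) {
  /* ErrorIndex::WORKER_DIED indexes the matching entry of the error_types
   * array defined in error_table.cc above. */
  push_error(db, driver_id, ErrorIndex::WORKER_DIED,
             "A worker died or was killed while executing a task.");
}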
-/// -/// \param db_handle Database handle. -/// \param driver_id The ID of the Python driver to push the error to. -/// \param error_type An index specifying the type of the error. This should -/// be a value from the ErrorIndex enum. -/// \param error_message The error message to print. -/// \return Void. -void push_error(DBHandle *db_handle, - DBClientID driver_id, - ErrorIndex error_type, - const std::string &error_message); - -#endif diff --git a/src/common/state/local_scheduler_table.cc b/src/common/state/local_scheduler_table.cc deleted file mode 100644 index 075d52102807c..0000000000000 --- a/src/common/state/local_scheduler_table.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "local_scheduler_table.h" - -#include "common_protocol.h" -#include "redis.h" - -void local_scheduler_table_subscribe( - DBHandle *db_handle, - local_scheduler_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry) { - LocalSchedulerTableSubscribeData *sub_data = - (LocalSchedulerTableSubscribeData *) malloc( - sizeof(LocalSchedulerTableSubscribeData)); - sub_data->subscribe_callback = subscribe_callback; - sub_data->subscribe_context = subscribe_context; - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(sub_data), retry, NULL, - redis_local_scheduler_table_subscribe, NULL); -} - -void local_scheduler_table_send_info(DBHandle *db_handle, - LocalSchedulerInfo *info, - RetryInfo *retry) { - /* Create a flatbuffer object to serialize and publish. */ - flatbuffers::FlatBufferBuilder fbb; - /* Create the flatbuffers message. */ - auto message = CreateLocalSchedulerInfoMessage( - fbb, to_flatbuf(fbb, db_handle->client), info->total_num_workers, - info->task_queue_length, info->available_workers, - map_to_flatbuf(fbb, info->static_resources), - map_to_flatbuf(fbb, info->dynamic_resources), false); - fbb.Finish(message); - - LocalSchedulerTableSendInfoData *data = - (LocalSchedulerTableSendInfoData *) malloc( - sizeof(LocalSchedulerTableSendInfoData) + fbb.GetSize()); - data->size = fbb.GetSize(); - memcpy(&data->flatbuffer_data[0], fbb.GetBufferPointer(), fbb.GetSize()); - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(data), retry, NULL, - redis_local_scheduler_table_send_info, NULL); -} - -void local_scheduler_table_disconnect(DBHandle *db_handle) { - redis_local_scheduler_table_disconnect(db_handle); -} diff --git a/src/common/state/local_scheduler_table.h b/src/common/state/local_scheduler_table.h deleted file mode 100644 index 239b84d0fa48e..0000000000000 --- a/src/common/state/local_scheduler_table.h +++ /dev/null @@ -1,98 +0,0 @@ -#ifndef LOCAL_SCHEDULER_TABLE_H -#define LOCAL_SCHEDULER_TABLE_H - -#include - -#include "db.h" -#include "table.h" -#include "task.h" - -/** This struct is sent with heartbeat messages from the local scheduler to the - * global scheduler, and it contains information about the load on the local - * scheduler. */ -typedef struct { - /** The total number of workers that are connected to this local scheduler. */ - int total_num_workers; - /** The number of tasks queued in this local scheduler. */ - int task_queue_length; - /** The number of workers that are available and waiting for tasks. */ - int available_workers; - /** The resource vector of resources generally available to this local - * scheduler. */ - std::unordered_map static_resources; - /** The resource vector of resources currently available to this local - * scheduler. 
*/ - std::unordered_map dynamic_resources; - /** Whether the local scheduler is dead. If true, then all other fields - * should be ignored. */ - bool is_dead; -} LocalSchedulerInfo; - -/* - * ==== Subscribing to the local scheduler table ==== - */ - -/* Callback for subscribing to the local scheduler table. */ -typedef void (*local_scheduler_table_subscribe_callback)( - DBClientID client_id, - LocalSchedulerInfo info, - void *user_context); - -/** - * Register a callback for a local scheduler table event. - * - * @param db_handle Database handle. - * @param subscribe_callback Callback that will be called when the local - * scheduler event happens. - * @param subscribe_context Context that will be passed into the - * subscribe_callback. - * @param retry Information about retrying the request to the database. - * @return Void. - */ -void local_scheduler_table_subscribe( - DBHandle *db_handle, - local_scheduler_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry); - -/* Data that is needed to register local scheduler table subscribe callbacks - * with the state database. */ -typedef struct { - local_scheduler_table_subscribe_callback subscribe_callback; - void *subscribe_context; -} LocalSchedulerTableSubscribeData; - -/** - * Send a heartbeat to all subscribers to the local scheduler table. This - * heartbeat contains some information about the load on the local scheduler. - * - * @param db_handle Database handle. - * @param info Information about the local scheduler, including the load on the - * local scheduler. - * @param retry Information about retrying the request to the database. - * @return Void. - */ -void local_scheduler_table_send_info(DBHandle *db_handle, - LocalSchedulerInfo *info, - RetryInfo *retry); - -/* Data that is needed to publish local scheduler heartbeats to the local - * scheduler table. */ -typedef struct { - /* The size of the flatbuffer object. */ - int64_t size; - /* The information to be sent. */ - uint8_t flatbuffer_data[0]; -} LocalSchedulerTableSendInfoData; - -/** - * Send a null heartbeat to all subscribers to the local scheduler table to - * notify them that we are about to exit. This operation is performed - * synchronously. - * - * @param db_handle Database handle. - * @return Void. 
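/* Illustrative sketch, not from the original sources; it depends on the
 * legacy local_scheduler_table.h deleted here. Publishing one heartbeat with
 * the load information described by LocalSchedulerInfo above. The resource
 * name and numbers are placeholders, and the NULL retry mirrors other callers
 * in these files (e.g. db_client_table_cache_init). */
void send_heartbeat_example(DBHandle *db) {
  LocalSchedulerInfo info;
  info.total_num_workers = 4;
  info.task_queue_length = 0;
  info.available_workers = 4;
  info.static_resources["CPU"] = 4.0;  /* Resources configured on this node. */
  info.dynamic_resources["CPU"] = 4.0; /* Resources currently free. */
  info.is_dead = false;
  local_scheduler_table_send_info(db, &info, NULL);
}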
- */ -void local_scheduler_table_disconnect(DBHandle *db_handle); - -#endif /* LOCAL_SCHEDULER_TABLE_H */ diff --git a/src/common/state/object_table.cc b/src/common/state/object_table.cc deleted file mode 100644 index fcd527e62f6a4..0000000000000 --- a/src/common/state/object_table.cc +++ /dev/null @@ -1,119 +0,0 @@ -#include "object_table.h" -#include "redis.h" - -void object_table_lookup(DBHandle *db_handle, - ObjectID object_id, - RetryInfo *retry, - object_table_lookup_done_callback done_callback, - void *user_context) { - RAY_CHECK(db_handle != NULL); - init_table_callback(db_handle, object_id, __func__, - new CommonCallbackData(NULL), retry, - (table_done_callback) done_callback, - redis_object_table_lookup, user_context); -} - -void object_table_add(DBHandle *db_handle, - ObjectID object_id, - int64_t object_size, - unsigned char digest[], - RetryInfo *retry, - object_table_done_callback done_callback, - void *user_context) { - RAY_CHECK(db_handle != NULL); - - ObjectTableAddData *info = - (ObjectTableAddData *) malloc(sizeof(ObjectTableAddData)); - info->object_size = object_size; - memcpy(&info->digest[0], digest, DIGEST_SIZE); - init_table_callback(db_handle, object_id, __func__, - new CommonCallbackData(info), retry, - (table_done_callback) done_callback, - redis_object_table_add, user_context); -} - -void object_table_remove(DBHandle *db_handle, - ObjectID object_id, - DBClientID *client_id, - RetryInfo *retry, - object_table_done_callback done_callback, - void *user_context) { - RAY_CHECK(db_handle != NULL); - /* Copy the client ID, if one was provided. */ - DBClientID *client_id_copy = NULL; - if (client_id != NULL) { - client_id_copy = (DBClientID *) malloc(sizeof(DBClientID)); - *client_id_copy = *client_id; - } - init_table_callback(db_handle, object_id, __func__, - new CommonCallbackData(client_id_copy), retry, - (table_done_callback) done_callback, - redis_object_table_remove, user_context); -} - -void object_table_subscribe_to_notifications( - DBHandle *db_handle, - bool subscribe_all, - object_table_object_available_callback object_available_callback, - void *subscribe_context, - RetryInfo *retry, - object_table_lookup_done_callback done_callback, - void *user_context) { - RAY_CHECK(db_handle != NULL); - ObjectTableSubscribeData *sub_data = - (ObjectTableSubscribeData *) malloc(sizeof(ObjectTableSubscribeData)); - sub_data->object_available_callback = object_available_callback; - sub_data->subscribe_context = subscribe_context; - sub_data->subscribe_all = subscribe_all; - - init_table_callback( - db_handle, ObjectID::nil(), __func__, new CommonCallbackData(sub_data), - retry, (table_done_callback) done_callback, - redis_object_table_subscribe_to_notifications, user_context); -} - -void object_table_request_notifications(DBHandle *db_handle, - int num_object_ids, - ObjectID object_ids[], - RetryInfo *retry) { - RAY_CHECK(db_handle != NULL); - RAY_CHECK(num_object_ids > 0); - ObjectTableRequestNotificationsData *data = - (ObjectTableRequestNotificationsData *) malloc( - sizeof(ObjectTableRequestNotificationsData) + - num_object_ids * sizeof(ObjectID)); - data->num_object_ids = num_object_ids; - memcpy(data->object_ids, object_ids, num_object_ids * sizeof(ObjectID)); - - init_table_callback(db_handle, ObjectID::nil(), __func__, - new CommonCallbackData(data), retry, NULL, - redis_object_table_request_notifications, NULL); -} - -void result_table_add(DBHandle *db_handle, - ObjectID object_id, - TaskID task_id, - bool is_put, - RetryInfo *retry, - result_table_done_callback 
done_callback, - void *user_context) { - ResultTableAddInfo *info = - (ResultTableAddInfo *) malloc(sizeof(ResultTableAddInfo)); - info->task_id = task_id; - info->is_put = is_put; - init_table_callback(db_handle, object_id, __func__, - new CommonCallbackData(info), retry, - (table_done_callback) done_callback, - redis_result_table_add, user_context); -} - -void result_table_lookup(DBHandle *db_handle, - ObjectID object_id, - RetryInfo *retry, - result_table_lookup_callback done_callback, - void *user_context) { - init_table_callback(db_handle, object_id, __func__, - new CommonCallbackData(NULL), retry, - (table_done_callback) done_callback, - redis_result_table_lookup, user_context); -} diff --git a/src/common/state/object_table.h b/src/common/state/object_table.h deleted file mode 100644 index 77a299dfd30a8..0000000000000 --- a/src/common/state/object_table.h +++ /dev/null @@ -1,242 +0,0 @@ -#ifndef OBJECT_TABLE_H -#define OBJECT_TABLE_H - -#include "common.h" -#include "table.h" -#include "db.h" -#include "task.h" - -/* - * ==== Lookup call and callback ==== - */ - -/* Callback called when the lookup completes. The callback should free - * the manager_vector array, but NOT the strings they are pointing to. If there - * was no entry at all for the object (the object had never been created - * before), then never_created will be true. - */ -typedef void (*object_table_lookup_done_callback)( - ObjectID object_id, - bool never_created, - const std::vector &manager_ids, - void *user_context); - -/* Callback called when object ObjectID is available. */ -typedef void (*object_table_object_available_callback)( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_ids, - void *user_context); - -/** - * Return the list of nodes storing object_id in their plasma stores. - * - * @param db_handle Handle to object_table database. - * @param object_id ID of the object being looked up. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Context passed by the caller. - * @return Void. - */ -void object_table_lookup(DBHandle *db_handle, - ObjectID object_id, - RetryInfo *retry, - object_table_lookup_done_callback done_callback, - void *user_context); - -/* - * ==== Add object call and callback ==== - */ - -/** - * Callback called when the object add/remove operation completes. - * - * @param object_id The ID of the object that was added or removed. - * @param success Whether the operation was successful or not. If this is false - * and the operation was an addition, the object was added, but there - * was a hash mismatch. - * @param user_context The user context that was passed into the add/remove - * call. - */ -typedef void (*object_table_done_callback)(ObjectID object_id, - bool success, - void *user_context); - -/** - * Add the plasma manager that created the db_handle to the - * list of plasma managers that have the object_id. - * - * @param db_handle Handle to db. - * @param object_id Object unique identifier. - * @param data_size Object data size. - * @param retry Information about retrying the request to the database. - * @param done_callback Callback to be called when lookup completes. - * @param user_context User context to be passed in the callbacks. - * @return Void. 
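/* Illustrative sketch, not from the original sources; it is written against
 * the legacy object_table.h removed in this change. A lookup paired with its
 * done callback, following the object_table_lookup_done_callback typedef
 * above; the manager vector's element type is assumed to be DBClientID. */
void example_lookup_done(ObjectID object_id, bool never_created,
                         const std::vector<DBClientID> &manager_ids,
                         void *user_context) {
  if (never_created) {
    /* No entry was ever written for this object. */
    return;
  }
  /* manager_ids lists the plasma managers whose stores hold the object. */
}

void example_lookup(DBHandle *db, ObjectID object_id) {
  object_table_lookup(db, object_id, NULL, example_lookup_done, NULL);
}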
- */ -void object_table_add(DBHandle *db_handle, - ObjectID object_id, - int64_t object_size, - unsigned char digest[], - RetryInfo *retry, - object_table_done_callback done_callback, - void *user_context); - -/** Data that is needed to add new objects to the object table. */ -typedef struct { - int64_t object_size; - unsigned char digest[DIGEST_SIZE]; -} ObjectTableAddData; - -/* - * ==== Remove object call and callback ==== - */ - -/** - * Object remove function. - * - * @param db_handle Handle to db. - * @param object_id Object unique identifier. - * @param client_id A pointer to the database client ID to remove. If this is - * set to NULL, then the client ID associated with db_handle will be - * removed. - * @param retry Information about retrying the request to the database. - * @param done_callback Callback to be called when lookup completes. - * @param user_context User context to be passed in the callbacks. - * @return Void. - */ -void object_table_remove(DBHandle *db_handle, - ObjectID object_id, - DBClientID *client_id, - RetryInfo *retry, - object_table_done_callback done_callback, - void *user_context); - -/* - * ==== Subscribe to be announced when new object available ==== - */ - -/** - * Set up a client-specific channel for receiving notifications about available - * objects from the object table. The callback will be called once per - * notification received on this channel. - * - * @param db_handle Handle to db. - * @param object_available_callback Callback to be called when new object - * becomes available. - * @param subscribe_context Caller context which will be passed to the - * object_available_callback. - * @param retry Information about retrying the request to the database. - * @param done_callback Callback to be called when subscription is installed. - * This is only used for the tests. - * @param user_context User context to be passed into the done callback. This is - * only used for the tests. - * @return Void. - */ -void object_table_subscribe_to_notifications( - DBHandle *db_handle, - bool subscribe_all, - object_table_object_available_callback object_available_callback, - void *subscribe_context, - RetryInfo *retry, - object_table_lookup_done_callback done_callback, - void *user_context); - -/** - * Request notifications about the availability of some objects from the object - * table. The notifications will be published to this client's object - * notification channel, which was set up by the method - * object_table_subscribe_to_notifications. - * - * @param db_handle Handle to db. - * @param object_ids The object IDs to receive notifications about. - * @param retry Information about retrying the request to the database. - * @return Void. - */ -void object_table_request_notifications(DBHandle *db, - int num_object_ids, - ObjectID object_ids[], - RetryInfo *retry); - -/** Data that is needed to run object_request_notifications requests. */ -typedef struct { - /** The number of object IDs. */ - int num_object_ids; - /** This field is used to store a variable number of object IDs. */ - ObjectID object_ids[0]; -} ObjectTableRequestNotificationsData; - -/** Data that is needed to register new object available callbacks with the - * state database. */ -typedef struct { - bool subscribe_all; - object_table_object_available_callback object_available_callback; - void *subscribe_context; -} ObjectTableSubscribeData; - -/* - * ==== Result table ==== - */ - -/** - * Callback called when the add/remove operation for a result table entry - * completes. 
*/ -typedef void (*result_table_done_callback)(ObjectID object_id, - void *user_context); - -/** Information about a result table entry to add. */ -typedef struct { - /** The task ID of the task that created the requested object. */ - TaskID task_id; - /** True if the object was created through a put, and false if created by - * return value. */ - bool is_put; -} ResultTableAddInfo; - -/** - * Add information about a new object to the object table. This - * is immutable information like the ID of the task that - * created the object. - * - * @param db_handle Handle to object_table database. - * @param object_id ID of the object to add. - * @param task_id ID of the task that creates this object. - * @param is_put A boolean that is true if the object was created through a - * ray.put, and false if the object was created by return value. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Context passed by the caller. - * @return Void. - */ -void result_table_add(DBHandle *db_handle, - ObjectID object_id, - TaskID task_id, - bool is_put, - RetryInfo *retry, - result_table_done_callback done_callback, - void *user_context); - -/** Callback called when the result table lookup completes. */ -typedef void (*result_table_lookup_callback)(ObjectID object_id, - TaskID task_id, - bool is_put, - void *user_context); - -/** - * Lookup the task that created an object in the result table. The return value - * is the task ID. - * - * @param db_handle Handle to object_table database. - * @param object_id ID of the object to lookup. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Context passed by the caller. - * @return Void. - */ -void result_table_lookup(DBHandle *db_handle, - ObjectID object_id, - RetryInfo *retry, - result_table_lookup_callback done_callback, - void *user_context); - -#endif /* OBJECT_TABLE_H */ diff --git a/src/common/state/redis.cc b/src/common/state/redis.cc deleted file mode 100644 index 17a3c8ce2d3a5..0000000000000 --- a/src/common/state/redis.cc +++ /dev/null @@ -1,1692 +0,0 @@ -/* Redis implementation of the global state store */ - -#include -#include -#include -#include - -extern "C" { -/* Including hiredis here is necessary on Windows for typedefs used in ae.h. */ -#include "hiredis/hiredis.h" -#include "hiredis/adapters/ae.h" -} - -#include "common.h" -#include "db.h" -#include "db_client_table.h" -#include "actor_notification_table.h" -#include "driver_table.h" -#include "local_scheduler_table.h" -#include "object_table.h" -#include "task.h" -#include "task_table.h" -#include "error_table.h" -#include "event_loop.h" -#include "redis.h" -#include "io.h" -#include "net.h" - -#include "format/common_generated.h" - -#include "common_protocol.h" - -#ifndef _WIN32 -/* This function is actually not declared in standard POSIX, so declare it. */ -extern int usleep(useconds_t usec); -#endif - -#define CHECK_REDIS_CONNECT(CONTEXT_TYPE, context, M, ...) \ - do { \ - CONTEXT_TYPE *_context = (context); \ - if (!_context) { \ - RAY_LOG(FATAL) << "could not allocate redis context"; \ - } \ - if (_context->err) { \ - RAY_LOG(ERROR) << M; \ - LOG_REDIS_ERROR(_context, ""); \ - exit(-1); \ - } \ - } while (0) - -/** - * A header for callbacks of a single Redis asynchronous command. 
The user must - * pass in the table operation's timer ID as the asynchronous command's - * privdata field when executing the asynchronous command. The user must define - * variable names for DB and CB_DATA. After this piece of code runs, DB - * will hold a reference to the database handle, CB_DATA will hold a reference - * to the callback data for this table operation. The user must pass in the - * redisReply pointer as the REPLY argument. - * - * This header also short-circuits the entire callback if: (1) there was no - * reply from Redis, or (2) the callback data for this table operation was - * already removed, meaning that the operation was already marked as succeeded - * or failed. - */ -#define REDIS_CALLBACK_HEADER(DB, CB_DATA, REPLY) \ - if ((REPLY) == NULL) { \ - return; \ - } \ - DBHandle *DB = (DBHandle *) c->data; \ - TableCallbackData *CB_DATA = outstanding_callbacks_find((int64_t) privdata); \ - if (CB_DATA == NULL) { \ - /* the callback data structure has been \ - * already freed; just ignore this reply */ \ - return; \ - } \ - do { \ - } while (0) - -redisAsyncContext *get_redis_context(DBHandle *db, UniqueID id) { - /* NOTE: The hash function used here must match the one in - * PyObjectID_redis_shard_hash in src/common/lib/python/common_extension.cc. - * Changes to the hash function should only be made through - * std::hash in src/common/common.h */ - std::hash index; - return db->contexts[index(id) % db->contexts.size()]; -} - -redisAsyncContext *get_redis_subscribe_context(DBHandle *db, UniqueID id) { - std::hash index; - return db->subscribe_contexts[index(id) % db->subscribe_contexts.size()]; -} - -void get_redis_shards(redisContext *context, - std::vector &db_shards_addresses, - std::vector &db_shards_ports) { - /* Get the total number of Redis shards in the system. */ - int num_attempts = 0; - redisReply *reply = NULL; - while (num_attempts < RayConfig::instance().redis_db_connect_retries()) { - /* Try to read the number of Redis shards from the primary shard. If the - * entry is present, exit. */ - reply = (redisReply *) redisCommand(context, "GET NumRedisShards"); - if (reply->type != REDIS_REPLY_NIL) { - break; - } - - /* Sleep for a little, and try again if the entry isn't there yet. */ - freeReplyObject(reply); - usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000); - num_attempts++; - continue; - } - RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries()) - << "No entry found for NumRedisShards"; - RAY_CHECK(reply->type == REDIS_REPLY_STRING) - << "Expected string, found Redis type " << reply->type - << " for NumRedisShards"; - int num_redis_shards = atoi(reply->str); - RAY_CHECK(num_redis_shards >= 1) << "Expected at least one Redis shard, " - << "found " << num_redis_shards; - freeReplyObject(reply); - - /* Get the addresses of all of the Redis shards. */ - num_attempts = 0; - while (num_attempts < RayConfig::instance().redis_db_connect_retries()) { - /* Try to read the Redis shard locations from the primary shard. If we find - * that all of them are present, exit. */ - reply = (redisReply *) redisCommand(context, "LRANGE RedisShards 0 -1"); - if (static_cast(reply->elements) == num_redis_shards) { - break; - } - - /* Sleep for a little, and try again if not all Redis shard addresses have - * been added yet. 
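/* The shard-selection rule used by get_redis_context above, reduced to a
 * self-contained sketch: hash the ID with the same std::hash everywhere and
 * take it modulo the number of shard connections. std::string stands in for
 * the UniqueID type here; the real code hashes the ID bytes the same way the
 * Python extension does. */
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

template <typename ContextT>
ContextT *example_pick_shard(const std::vector<ContextT *> &shards,
                             const std::string &id) {
  std::hash<std::string> hash_fn;
  return shards[hash_fn(id) % shards.size()];
}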
*/ - freeReplyObject(reply); - usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000); - num_attempts++; - continue; - } - RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries()) - << "Expected " << num_redis_shards << " Redis shard addresses, found " - << reply->elements; - - /* Parse the Redis shard addresses. */ - char db_shard_address[16]; - int db_shard_port; - for (size_t i = 0; i < reply->elements; ++i) { - /* Parse the shard addresses and ports. */ - RAY_CHECK(reply->element[i]->type == REDIS_REPLY_STRING); - RAY_CHECK(parse_ip_addr_port(reply->element[i]->str, db_shard_address, - &db_shard_port) == 0); - db_shards_addresses.push_back(std::string(db_shard_address)); - db_shards_ports.push_back(db_shard_port); - } - freeReplyObject(reply); -} - -void db_connect_shard(const std::string &db_address, - int db_port, - DBClientID client, - const char *client_type, - const char *node_ip_address, - const std::vector &args, - DBHandle *db, - redisAsyncContext **context_out, - redisAsyncContext **subscribe_context_out, - redisContext **sync_context_out) { - /* Synchronous connection for initial handshake */ - redisReply *reply; - int connection_attempts = 0; - redisContext *sync_context = redisConnect(db_address.c_str(), db_port); - while (sync_context == NULL || sync_context->err) { - if (connection_attempts >= - RayConfig::instance().redis_db_connect_retries()) { - break; - } - RAY_LOG(WARNING) << "Failed to connect to Redis, retrying."; - /* Sleep for a little. */ - usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000); - sync_context = redisConnect(db_address.c_str(), db_port); - connection_attempts += 1; - } - CHECK_REDIS_CONNECT(redisContext, sync_context, - "could not establish synchronous connection to redis " - "%s:%d", - db_address.c_str(), db_port); - /* Configure Redis to generate keyspace notifications for list events. This - * should only need to be done once (by whoever started Redis), but since - * Redis may be started in multiple places (e.g., for testing or when starting - * processes by hand), it is easier to do it multiple times. */ - reply = (redisReply *) redisCommand(sync_context, - "CONFIG SET notify-keyspace-events Kl"); - RAY_CHECK(reply != NULL) << "db_connect failed on CONFIG SET"; - freeReplyObject(reply); - /* Also configure Redis to not run in protected mode, so clients on other - * hosts can connect to it. */ - reply = - (redisReply *) redisCommand(sync_context, "CONFIG SET protected-mode no"); - RAY_CHECK(reply != NULL) << "db_connect failed on CONFIG SET"; - freeReplyObject(reply); - - /* Construct the argument arrays for RAY.CONNECT. */ - int argc = args.size() + 4; - const char **argv = (const char **) malloc(sizeof(char *) * argc); - size_t *argvlen = (size_t *) malloc(sizeof(size_t) * argc); - /* Set the command name argument. */ - argv[0] = "RAY.CONNECT"; - argvlen[0] = strlen(argv[0]); - /* Set the client ID argument. */ - argv[1] = (char *) client.data(); - argvlen[1] = sizeof(client); - /* Set the node IP address argument. */ - argv[2] = node_ip_address; - argvlen[2] = strlen(node_ip_address); - /* Set the client type argument. */ - argv[3] = client_type; - argvlen[3] = strlen(client_type); - /* Set the remaining arguments. */ - for (size_t i = 0; i < args.size(); ++i) { - argv[4 + i] = args[i].c_str(); - argvlen[4 + i] = strlen(args[i].c_str()); - } - - /* Register this client with Redis. RAY.CONNECT is a custom Redis command that - * we've defined. 
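/* A condensed sketch of the argv/argvlen construction used for RAY.CONNECT
 * above. redisCommandArgv takes explicit lengths, which is what allows a raw
 * binary client ID to be passed next to ordinary C strings. The argument
 * values here are placeholders. */
#include <cstring>
#include <vector>
extern "C" {
#include "hiredis/hiredis.h"
}

redisReply *example_send_connect(redisContext *ctx,
                                 const char *client_id, size_t client_id_len,
                                 const char *node_ip, const char *client_type) {
  std::vector<const char *> argv = {"RAY.CONNECT", client_id, node_ip,
                                    client_type};
  std::vector<size_t> argvlen = {strlen(argv[0]), client_id_len,
                                 strlen(node_ip), strlen(client_type)};
  return (redisReply *) redisCommandArgv(ctx, (int) argv.size(), argv.data(),
                                         argvlen.data());
}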
*/ - reply = (redisReply *) redisCommandArgv(sync_context, argc, argv, argvlen); - RAY_CHECK(reply != NULL) << "db_connect failed on RAY.CONNECT"; - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - freeReplyObject(reply); - free(argv); - free(argvlen); - - *sync_context_out = sync_context; - - /* Establish connection for control data. */ - redisAsyncContext *context = redisAsyncConnect(db_address.c_str(), db_port); - CHECK_REDIS_CONNECT(redisAsyncContext, context, - "could not establish asynchronous connection to redis " - "%s:%d", - db_address.c_str(), db_port); - context->data = (void *) db; - *context_out = context; - - /* Establish async connection for subscription. */ - redisAsyncContext *subscribe_context = - redisAsyncConnect(db_address.c_str(), db_port); - CHECK_REDIS_CONNECT(redisAsyncContext, subscribe_context, - "could not establish asynchronous subscription " - "connection to redis %s:%d", - db_address.c_str(), db_port); - subscribe_context->data = (void *) db; - *subscribe_context_out = subscribe_context; -} - -DBHandle *db_connect(const std::string &db_primary_address, - int db_primary_port, - const char *client_type, - const char *node_ip_address, - const std::vector &args) { - /* Check that the number of args is even. These args will be passed to the - * RAY.CONNECT Redis command, which takes arguments in pairs. */ - if (args.size() % 2 != 0) { - RAY_LOG(FATAL) << "The number of extra args must be divisible by two."; - } - - /* Create a client ID for this client. */ - DBClientID client = DBClientID::from_random(); - - DBHandle *db = new DBHandle(); - - db->client_type = strdup(client_type); - db->client = client; - - redisAsyncContext *context; - redisAsyncContext *subscribe_context; - redisContext *sync_context; - - /* Connect to the primary redis instance. */ - db_connect_shard(db_primary_address, db_primary_port, client, client_type, - node_ip_address, args, db, &context, &subscribe_context, - &sync_context); - db->context = context; - db->subscribe_context = subscribe_context; - db->sync_context = sync_context; - - /* Get the shard locations. */ - std::vector db_shards_addresses; - std::vector db_shards_ports; - get_redis_shards(db->sync_context, db_shards_addresses, db_shards_ports); - RAY_CHECK(db_shards_addresses.size() > 0) << "No Redis shards found"; - /* Connect to the shards. */ - for (size_t i = 0; i < db_shards_addresses.size(); ++i) { - db_connect_shard(db_shards_addresses[i], db_shards_ports[i], client, - client_type, node_ip_address, args, db, &context, - &subscribe_context, &sync_context); - db->contexts.push_back(context); - db->subscribe_contexts.push_back(subscribe_context); - redisFree(sync_context); - } - - return db; -} - -void DBHandle_free(DBHandle *db) { - /* Clean up the primary Redis connection state. */ - redisFree(db->sync_context); - redisAsyncFree(db->context); - redisAsyncFree(db->subscribe_context); - - /* Clean up the Redis shards. */ - RAY_CHECK(db->contexts.size() == db->subscribe_contexts.size()); - for (size_t i = 0; i < db->contexts.size(); ++i) { - redisAsyncFree(db->contexts[i]); - redisAsyncFree(db->subscribe_contexts[i]); - } - - free(db->client_type); - delete db; -} - -void db_disconnect(DBHandle *db) { - /* Notify others that this client is disconnecting from Redis. If a client of - * the same type on the same node wants to reconnect again, they must - * reconnect and get assigned a different client ID. 
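/* A sketch of the connection lifecycle implemented above, as a caller would
 * use it: db_connect reaches the primary shard, discovers the data shards via
 * NumRedisShards/RedisShards, and connects to each of them; db_attach (just
 * below) hooks the async contexts into the event loop; db_disconnect issues
 * RAY.DISCONNECT and frees the handle. The client type and the extra
 * key/value args here are placeholders; the args must come in pairs. */
#include <string>
#include <vector>
#include "db.h"

void example_connect(event_loop *loop) {
  std::vector<std::string> args = {"example_key", "example_value"};
  DBHandle *db = db_connect("127.0.0.1", 6379, "example_client", "127.0.0.1",
                            args);
  db_attach(db, loop, /*reattach=*/false);
  /* ... issue object/task/result table operations ... */
  db_disconnect(db);
}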
*/ - redisReply *reply = - (redisReply *) redisCommand(db->sync_context, "RAY.DISCONNECT %b", - db->client.data(), sizeof(db->client)); - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - freeReplyObject(reply); - - DBHandle_free(db); -} - -void db_attach(DBHandle *db, event_loop *loop, bool reattach) { - db->loop = loop; - /* Attach primary redis instance to the event loop. */ - int err = redisAeAttach(loop, db->context); - /* If the database is reattached in the tests, redis normally gives - * an error which we can safely ignore. */ - if (!reattach) { - RAY_CHECK(err == REDIS_OK) << "failed to attach the event loop"; - } - err = redisAeAttach(loop, db->subscribe_context); - if (!reattach) { - RAY_CHECK(err == REDIS_OK) << "failed to attach the event loop"; - } - /* Attach other redis shards to the event loop. */ - RAY_CHECK(db->contexts.size() == db->subscribe_contexts.size()); - for (size_t i = 0; i < db->contexts.size(); ++i) { - int err = redisAeAttach(loop, db->contexts[i]); - /* If the database is reattached in the tests, redis normally gives - * an error which we can safely ignore. */ - if (!reattach) { - RAY_CHECK(err == REDIS_OK) << "failed to attach the event loop"; - } - err = redisAeAttach(loop, db->subscribe_contexts[i]); - if (!reattach) { - RAY_CHECK(err == REDIS_OK) << "failed to attach the event loop"; - } - } -} - -/* - * ==== object_table callbacks ==== - */ - -void redis_object_table_add_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - /* Do some minimal checking. */ - redisReply *reply = (redisReply *) r; - bool success = (strcmp(reply->str, "hash mismatch") != 0); - if (!success) { - /* If our object hash doesn't match the one recorded in the table, report - * the error back to the user and exit immediately. */ - RAY_LOG(WARNING) << "Found objects with different value but same object " - << "ID, most likely because a nondeterministic task was " - << "executed twice, either for reconstruction or for " - << "speculation."; - } else { - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " - << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - } - /* Call the done callback if there is one. */ - if (callback_data->done_callback != NULL) { - object_table_done_callback done_callback = - (object_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, success, callback_data->user_context); - } - /* Clean up the timer and callback. 
*/ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_object_table_add(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - - ObjectTableAddData *info = (ObjectTableAddData *) callback_data->data->Get(); - ObjectID obj_id = callback_data->id; - int64_t object_size = info->object_size; - unsigned char *digest = info->digest; - - redisAsyncContext *context = get_redis_context(db, obj_id); - - int status = redisAsyncCommand( - context, redis_object_table_add_callback, - (void *) callback_data->timer_id, "RAY.OBJECT_TABLE_ADD %b %lld %b %b", - obj_id.data(), sizeof(obj_id), (long long) object_size, digest, - (size_t) DIGEST_SIZE, db->client.data(), sizeof(db->client)); - - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_object_table_add"); - } -} - -void redis_object_table_remove_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - /* Do some minimal checking. */ - redisReply *reply = (redisReply *) r; - if (strcmp(reply->str, "object not found") == 0) { - /* If our object entry was not in the table, it's probably a race - * condition with an object_table_add. */ - return; - } - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - /* Call the done callback if there is one. */ - if (callback_data->done_callback != NULL) { - object_table_done_callback done_callback = - (object_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, true, callback_data->user_context); - } - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_object_table_remove(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - - ObjectID obj_id = callback_data->id; - /* If the caller provided a manager ID to delete, use it. Otherwise, use our - * own client ID as the ID to delete. */ - DBClientID *client_id = (DBClientID *) callback_data->data->Get(); - if (client_id == NULL) { - client_id = &db->client; - } - - redisAsyncContext *context = get_redis_context(db, obj_id); - - int status = redisAsyncCommand( - context, redis_object_table_remove_callback, - (void *) callback_data->timer_id, "RAY.OBJECT_TABLE_REMOVE %b %b", - obj_id.data(), sizeof(obj_id), client_id->data(), sizeof(*client_id)); - - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_object_table_remove"); - } -} - -void redis_object_table_lookup(TableCallbackData *callback_data) { - RAY_CHECK(callback_data); - DBHandle *db = callback_data->db_handle; - - ObjectID obj_id = callback_data->id; - - redisAsyncContext *context = get_redis_context(db, obj_id); - - int status = redisAsyncCommand(context, redis_object_table_lookup_callback, - (void *) callback_data->timer_id, - "RAY.OBJECT_TABLE_LOOKUP %b", obj_id.data(), - sizeof(obj_id)); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in object_table lookup"); - } -} - -void redis_result_table_add_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - /* Check that the command succeeded. 
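/* The %b conversions in the table commands above each consume a
 * (pointer, length) pair, so fixed-size binary IDs can be embedded directly
 * in a command without escaping. A self-contained illustration with a
 * placeholder key and value: */
extern "C" {
#include "hiredis/async.h"
}

int example_async_set(redisAsyncContext *ctx, const void *key, size_t key_len,
                      const void *val, size_t val_len) {
  /* No reply callback and no privdata are needed for this fire-and-forget
   * example. */
  return redisAsyncCommand(ctx, NULL, NULL, "SET %b %b", key, key_len,
                           val, val_len);
}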
*/ - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strncmp(reply->str, "OK", strlen("OK")) == 0) << "reply->str is " - << reply->str; - /* Call the done callback if there is one. */ - if (callback_data->done_callback) { - result_table_done_callback done_callback = - (result_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, callback_data->user_context); - } - destroy_timer_callback(db->loop, callback_data); -} - -void redis_result_table_add(TableCallbackData *callback_data) { - RAY_CHECK(callback_data); - DBHandle *db = callback_data->db_handle; - ObjectID id = callback_data->id; - ResultTableAddInfo *info = (ResultTableAddInfo *) callback_data->data->Get(); - int is_put = info->is_put ? 1 : 0; - - redisAsyncContext *context = get_redis_context(db, id); - - /* Add the result entry to the result table. */ - int status = - redisAsyncCommand(context, redis_result_table_add_callback, - (void *) callback_data->timer_id, - "RAY.RESULT_TABLE_ADD %b %b %d", id.data(), sizeof(id), - info->task_id.data(), sizeof(info->task_id), is_put); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "Error in result table add"); - } -} - -/* This allocates a task which must be freed by the caller, unless the returned - * task is NULL. This is used by both redis_result_table_lookup_callback and - * redis_task_table_get_task_callback. */ -Task *parse_and_construct_task_from_redis_reply(redisReply *reply) { - Task *task = NULL; - if (reply->type == REDIS_REPLY_NIL) { - /* There is no task in the reply, so return NULL. */ - } else if (reply->type == REDIS_REPLY_STRING) { - /* The reply is a flatbuffer TaskReply object. Parse it and construct the - * task. */ - auto message = flatbuffers::GetRoot(reply->str); - TaskSpec *spec = (TaskSpec *) message->task_spec()->data(); - int64_t task_spec_size = message->task_spec()->size(); - auto execution_dependencies = - flatbuffers::GetRoot( - message->execution_dependencies()->data()); - task = Task_alloc( - spec, task_spec_size, static_cast(message->state()), - from_flatbuf(*message->local_scheduler_id()), - from_flatbuf(*execution_dependencies->execution_dependencies())); - } else { - RAY_LOG(FATAL) << "Unexpected reply type " << reply->type; - } - /* Return the task. If it is not NULL, then it must be freed by the caller. */ - return task; -} - -void redis_result_table_lookup_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_NIL || reply->type == REDIS_REPLY_STRING) - << "Unexpected reply type " << reply->type << " in " - << "redis_result_table_lookup_callback"; - /* Parse the task from the reply. */ - TaskID result_id = TaskID::nil(); - bool is_put = false; - if (reply->type == REDIS_REPLY_STRING) { - auto message = flatbuffers::GetRoot(reply->str); - result_id = from_flatbuf(*message->task_id()); - is_put = message->is_put(); - } - - /* Call the done callback if there is one. */ - result_table_lookup_callback done_callback = - (result_table_lookup_callback) callback_data->done_callback; - if (done_callback != NULL) { - done_callback(callback_data->id, result_id, is_put, - callback_data->user_context); - } - /* Clean up timer and callback. 
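/* The reply-parsing pattern used by parse_and_construct_task_from_redis_reply
 * above, isolated: a string reply from the Redis module is treated as a
 * serialized flatbuffer and read in place with flatbuffers::GetRoot.
 * ExampleMessage stands in for the generated message types (e.g. the task
 * reply) from common_generated.h. */
#include "flatbuffers/flatbuffers.h"
extern "C" {
#include "hiredis/hiredis.h"
}

template <typename ExampleMessage>
const ExampleMessage *example_parse_reply(redisReply *reply) {
  if (reply == NULL || reply->type == REDIS_REPLY_NIL) {
    return NULL;  /* No entry found for this key. */
  }
  if (reply->type != REDIS_REPLY_STRING) {
    return NULL;  /* Unexpected reply type. */
  }
  /* The returned pointer aliases reply->str, so it must be consumed before
   * freeReplyObject(reply) is called. */
  return flatbuffers::GetRoot<ExampleMessage>(reply->str);
}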
*/ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_result_table_lookup(TableCallbackData *callback_data) { - RAY_CHECK(callback_data); - DBHandle *db = callback_data->db_handle; - ObjectID id = callback_data->id; - redisAsyncContext *context = get_redis_context(db, id); - int status = - redisAsyncCommand(context, redis_result_table_lookup_callback, - (void *) callback_data->timer_id, - "RAY.RESULT_TABLE_LOOKUP %b", id.data(), sizeof(id)); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "Error in result table lookup"); - } -} - -DBClient redis_db_client_table_get(DBHandle *db, - const unsigned char *client_id, - size_t client_id_len) { - redisReply *reply = - (redisReply *) redisCommand(db->sync_context, "HGETALL %s%b", - DB_CLIENT_PREFIX, client_id, client_id_len); - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements > 0); - DBClient db_client; - int num_fields = 0; - /* Parse the fields into a DBClient. */ - for (size_t j = 0; j < reply->elements; j = j + 2) { - const char *key = reply->element[j]->str; - const char *value = reply->element[j + 1]->str; - if (strcmp(key, "ray_client_id") == 0) { - memcpy(db_client.id.mutable_data(), value, sizeof(db_client.id)); - num_fields++; - } else if (strcmp(key, "client_type") == 0) { - db_client.client_type = std::string(value); - num_fields++; - } else if (strcmp(key, "manager_address") == 0) { - db_client.manager_address = std::string(value); - num_fields++; - } else if (strcmp(key, "deleted") == 0) { - bool is_deleted = atoi(value); - db_client.is_alive = !is_deleted; - num_fields++; - } - } - freeReplyObject(reply); - /* The client ID, type, and whether it is deleted are all - * mandatory fields. Auxiliary address is optional. */ - RAY_CHECK(num_fields >= 3); - return db_client; -} - -void redis_cache_set_db_client(DBHandle *db, DBClient client) { - db->db_client_cache[client.id] = client; -} - -/** - * Get an entry from the plasma manager table in redis. - * - * @param db The database handle. - * @param index The index of the plasma manager. - * @return The IP address and port of the manager. - */ -DBClient redis_cache_get_db_client(DBHandle *db, DBClientID db_client_id) { - auto it = db->db_client_cache.find(db_client_id); - if (it == db->db_client_cache.end()) { - DBClient db_client = redis_db_client_table_get(db, db_client_id.data(), - sizeof(db_client_id)); - db->db_client_cache[db_client_id] = db_client; - it = db->db_client_cache.find(db_client_id); - } - return it->second; -} - -void redis_object_table_lookup_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - RAY_LOG(DEBUG) << "Object table lookup callback"; - RAY_CHECK(reply->type == REDIS_REPLY_NIL || reply->type == REDIS_REPLY_ARRAY); - - object_table_lookup_done_callback done_callback = - (object_table_lookup_done_callback) callback_data->done_callback; - - ObjectID obj_id = callback_data->id; - - /* Parse the Redis reply. */ - if (reply->type == REDIS_REPLY_NIL) { - /* The object entry did not exist. */ - if (done_callback) { - done_callback(obj_id, true, std::vector(), - callback_data->user_context); - } - } else if (reply->type == REDIS_REPLY_ARRAY) { - /* Extract the manager IDs from the response into a vector. 
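/* HGETALL replies, as parsed by redis_db_client_table_get above, arrive as a
 * flat array alternating field name and value. A minimal self-contained walk
 * over such a reply: */
#include <cstdio>
extern "C" {
#include "hiredis/hiredis.h"
}

void example_print_hash(redisReply *reply) {
  if (reply == NULL || reply->type != REDIS_REPLY_ARRAY) {
    return;
  }
  for (size_t j = 0; j + 1 < reply->elements; j += 2) {
    printf("%s = %s\n", reply->element[j]->str, reply->element[j + 1]->str);
  }
}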
*/ - std::vector manager_ids; - - for (size_t j = 0; j < reply->elements; ++j) { - RAY_CHECK(reply->element[j]->type == REDIS_REPLY_STRING); - DBClientID manager_id; - memcpy(manager_id.mutable_data(), reply->element[j]->str, - sizeof(manager_id)); - manager_ids.push_back(manager_id); - } - - if (done_callback) { - done_callback(obj_id, false, manager_ids, callback_data->user_context); - } - } else { - RAY_LOG(FATAL) << "Unexpected reply type from object table lookup."; - } - - /* Clean up timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void object_table_redis_subscribe_to_notifications_callback( - redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - /* Replies to the SUBSCRIBE command have 3 elements. There are two - * possibilities. Either the reply is the initial acknowledgment of the - * subscribe command, or it is a message. If it is the initial acknowledgment, - * then - * - reply->element[0]->str is "subscribe" - * - reply->element[1]->str is the name of the channel - * - reply->emement[2]->str is null. - * If it is an actual message, then - * - reply->element[0]->str is "message" - * - reply->element[1]->str is the name of the channel - * - reply->emement[2]->str is the contents of the message. - */ - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements == 3); - redisReply *message_type = reply->element[0]; - RAY_LOG(DEBUG) << "Object table subscribe to notifications callback, message" - << message_type->str; - - if (strcmp(message_type->str, "message") == 0) { - /* We received an object notification. Parse the payload. */ - auto message = flatbuffers::GetRoot( - reply->element[2]->str); - /* Extract the object ID. */ - ObjectID obj_id = from_flatbuf(*message->object_id()); - /* Extract the data size. */ - int64_t data_size = message->object_size(); - int manager_count = message->manager_ids()->size(); - - /* Extract the manager IDs from the response into a vector. */ - std::vector manager_ids; - for (int i = 0; i < manager_count; ++i) { - DBClientID manager_id = from_flatbuf(*message->manager_ids()->Get(i)); - manager_ids.push_back(manager_id); - } - - /* Call the subscribe callback. */ - ObjectTableSubscribeData *data = - (ObjectTableSubscribeData *) callback_data->data->Get(); - if (data->object_available_callback) { - data->object_available_callback(obj_id, data_size, manager_ids, - data->subscribe_context); - } - } else if (strcmp(message_type->str, "subscribe") == 0) { - /* The reply for the initial SUBSCRIBE command. */ - /* Call the done callback if there is one. This code path should only be - * used in the tests. */ - if (callback_data->done_callback != NULL) { - object_table_lookup_done_callback done_callback = - (object_table_lookup_done_callback) callback_data->done_callback; - done_callback(ray::UniqueID::nil(), false, std::vector(), - callback_data->user_context); - } - /* If the initial SUBSCRIBE was successful, clean up the timer, but don't - * destroy the callback data. */ - remove_timer_callback(db->loop, callback_data); - } else { - RAY_LOG(FATAL) << "Unexpected reply type from object table subscribe to " - << "notifications."; - } -} - -void redis_object_table_subscribe_to_notifications( - TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - /* The object channel prefix must match the value defined in - * src/common/redismodule/ray_redis_module.cc. 
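/* The dispatch performed by the subscription callbacks above, isolated: every
 * delivery on a subscribed connection is a three-element array whose first
 * element distinguishes the initial "subscribe" acknowledgment from an actual
 * "message", and whose last element carries the payload (null for the
 * acknowledgment). */
#include <cstring>
extern "C" {
#include "hiredis/hiredis.h"
}

void example_handle_subscribe_reply(redisReply *reply) {
  if (reply == NULL || reply->type != REDIS_REPLY_ARRAY ||
      reply->elements != 3) {
    return;
  }
  if (strcmp(reply->element[0]->str, "subscribe") == 0) {
    /* Initial acknowledgment; the payload element is null. */
  } else if (strcmp(reply->element[0]->str, "message") == 0) {
    const char *payload = reply->element[2]->str;
    /* Parse the payload (here, a flatbuffer) and invoke the user callback. */
    (void) payload;
  }
}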
*/ - const char *object_channel_prefix = "OC:"; - const char *object_channel_bcast = "BCAST"; - for (size_t i = 0; i < db->subscribe_contexts.size(); ++i) { - int status = REDIS_OK; - /* Subscribe to notifications from the object table. This uses the client ID - * as the channel name so this channel is specific to this client. - * TODO(rkn): - * The channel name should probably be the client ID with some prefix. */ - RAY_CHECK(callback_data->data->Get() != NULL) - << "Object table subscribe data passed as NULL."; - if (((ObjectTableSubscribeData *) (callback_data->data->Get())) - ->subscribe_all) { - /* Subscribe to the object broadcast channel. */ - status = redisAsyncCommand( - db->subscribe_contexts[i], - object_table_redis_subscribe_to_notifications_callback, - (void *) callback_data->timer_id, "SUBSCRIBE %s%s", - object_channel_prefix, object_channel_bcast); - } else { - status = redisAsyncCommand( - db->subscribe_contexts[i], - object_table_redis_subscribe_to_notifications_callback, - (void *) callback_data->timer_id, "SUBSCRIBE %s%b", - object_channel_prefix, db->client.data(), sizeof(db->client)); - } - - if ((status == REDIS_ERR) || db->subscribe_contexts[i]->err) { - LOG_REDIS_DEBUG(db->subscribe_contexts[i], - "error in redis_object_table_subscribe_to_notifications"); - } - } -} - -void redis_object_table_request_notifications_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - /* Do some minimal checking. */ - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - RAY_CHECK(callback_data->done_callback == NULL); - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_object_table_request_notifications( - TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - - ObjectTableRequestNotificationsData *request_data = - (ObjectTableRequestNotificationsData *) callback_data->data->Get(); - int num_object_ids = request_data->num_object_ids; - ObjectID *object_ids = request_data->object_ids; - - for (int i = 0; i < num_object_ids; ++i) { - redisAsyncContext *context = get_redis_context(db, object_ids[i]); - - /* Create the arguments for the Redis command. */ - int num_args = 1 + 1 + 1; - const char **argv = (const char **) malloc(sizeof(char *) * num_args); - size_t *argvlen = (size_t *) malloc(sizeof(size_t) * num_args); - /* Set the command name argument. */ - argv[0] = "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS"; - argvlen[0] = strlen(argv[0]); - /* Set the client ID argument. */ - argv[1] = (char *) db->client.data(); - argvlen[1] = sizeof(db->client); - /* Set the object ID arguments. */ - argv[2] = (char *) object_ids[i].data(); - argvlen[2] = sizeof(object_ids[i]); - - int status = redisAsyncCommandArgv( - context, redis_object_table_request_notifications_callback, - (void *) callback_data->timer_id, num_args, argv, argvlen); - free(argv); - free(argvlen); - - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, - "error in redis_object_table_subscribe_to_notifications"); - } - } -} - -/* - * ==== task_table callbacks ==== - */ - -void redis_task_table_get_task_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - /* Parse the task from the reply. 
*/ - Task *task = parse_and_construct_task_from_redis_reply(reply); - /* Call the done callback if there is one. */ - task_table_get_callback done_callback = - (task_table_get_callback) callback_data->done_callback; - if (done_callback != NULL) { - done_callback(task, callback_data->user_context); - } - /* Free the task if it is not NULL. */ - if (task != NULL) { - Task_free(task); - } - - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_task_table_get_task(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - RAY_CHECK(callback_data->data->Get() == NULL); - TaskID task_id = callback_data->id; - - redisAsyncContext *context = get_redis_context(db, task_id); - - int status = redisAsyncCommand(context, redis_task_table_get_task_callback, - (void *) callback_data->timer_id, - "RAY.TASK_TABLE_GET %b", task_id.data(), - sizeof(task_id)); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_task_table_get_task"); - } -} - -void redis_task_table_add_task_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - // If no subscribers received the message, call the failure callback. The - // caller should decide whether to retry the add. NOTE(swang): The caller - // should check whether the receiving subscriber is still alive in the - // db_client table before retrying the add. - if (reply->type == REDIS_REPLY_ERROR && - strcmp(reply->str, "No subscribers received message.") == 0) { - RAY_LOG(WARNING) << "No subscribers received the task_table_add message."; - if (callback_data->retry.fail_callback != NULL) { - callback_data->retry.fail_callback(callback_data->id, - callback_data->user_context, - callback_data->data->Get()); - } - } else { - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " - << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - /* Call the done callback if there is one. */ - if (callback_data->done_callback != NULL) { - task_table_done_callback done_callback = - (task_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, callback_data->user_context); - } - } - - /* Clean up the timer and callback. 
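/* The failure path above in outline: when the Redis module reports that no
 * subscriber received a published task, the table layer invokes the caller's
 * fail_callback from RetryInfo instead of the done callback, and the caller
 * decides whether to re-issue the add. The callback shape below is inferred
 * from the invocation above (id, user_context, user_data); the real typedef
 * lives in table headers that are not part of this diff. */
static void example_on_task_add_failed(TaskID task_id, void *user_context,
                                       void *user_data) {
  /* Only retry the add if the destination local scheduler is still listed as
   * alive in the db client table (see the NOTE above). */
}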
*/ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_task_table_add_task(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - Task *task = (Task *) callback_data->data->Get(); - RAY_CHECK(task != NULL) << "NULL task passed to redis_task_table_add_task."; - - TaskID task_id = Task_task_id(task); - DBClientID local_scheduler_id = Task_local_scheduler(task); - redisAsyncContext *context = get_redis_context(db, task_id); - int state = static_cast(Task_state(task)); - - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - TaskSpec *spec = execution_spec->Spec(); - - flatbuffers::FlatBufferBuilder fbb; - auto execution_dependencies = CreateTaskExecutionDependencies( - fbb, to_flatbuf(fbb, execution_spec->ExecutionDependencies())); - fbb.Finish(execution_dependencies); - - int status = redisAsyncCommand( - context, redis_task_table_add_task_callback, - (void *) callback_data->timer_id, "RAY.TASK_TABLE_ADD %b %d %b %b %d %b", - task_id.data(), sizeof(task_id), state, local_scheduler_id.data(), - sizeof(local_scheduler_id), fbb.GetBufferPointer(), - (size_t) fbb.GetSize(), - static_cast(execution_spec->SpillbackCount()), spec, - execution_spec->SpecSize()); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_task_table_add_task"); - } -} - -void redis_task_table_update_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - // If no subscribers received the message, call the failure callback. The - // caller should decide whether to retry the update. NOTE(swang): Retrying a - // task table update can race with the liveness monitor. Do not retry the - // update unless the caller is sure that the receiving subscriber is still - // alive in the db_client table. - if (reply->type == REDIS_REPLY_ERROR) { - RAY_LOG(WARNING) << "task_table_update failed with " << reply->str; - if (callback_data->retry.fail_callback != NULL) { - callback_data->retry.fail_callback(callback_data->id, - callback_data->user_context, - callback_data->data->Get()); - } else { - RAY_LOG(FATAL) << "task_table_update failed and no fail_callback is set"; - } - } else { - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - - /* Call the done callback if there is one. */ - if (callback_data->done_callback != NULL) { - task_table_done_callback done_callback = - (task_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, callback_data->user_context); - } - } - - /* Clean up the timer and callback. 
*/ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_task_table_update(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - Task *task = (Task *) callback_data->data->Get(); - RAY_CHECK(task != NULL) << "NULL task passed to redis_task_table_update."; - - TaskID task_id = Task_task_id(task); - redisAsyncContext *context = get_redis_context(db, task_id); - DBClientID local_scheduler_id = Task_local_scheduler(task); - int state = static_cast(Task_state(task)); - - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - flatbuffers::FlatBufferBuilder fbb; - auto execution_dependencies = CreateTaskExecutionDependencies( - fbb, to_flatbuf(fbb, execution_spec->ExecutionDependencies())); - fbb.Finish(execution_dependencies); - - int status = redisAsyncCommand( - context, redis_task_table_update_callback, - (void *) callback_data->timer_id, "RAY.TASK_TABLE_UPDATE %b %d %b %b %d", - task_id.data(), sizeof(task_id), state, local_scheduler_id.data(), - sizeof(local_scheduler_id), fbb.GetBufferPointer(), - (size_t) fbb.GetSize(), - static_cast(execution_spec->SpillbackCount())); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_task_table_update"); - } -} - -void redis_task_table_test_and_update_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - /* Parse the task from the reply. */ - Task *task = parse_and_construct_task_from_redis_reply(reply); - if (task == NULL) { - /* A NULL task means that the task was not in the task table. NOTE(swang): - * For normal tasks, this is not expected behavior, but actor tasks may be - * delayed when added to the task table if they are submitted to a local - * scheduler before it receives the notification that maps the actor to a - * local scheduler. */ - RAY_LOG(ERROR) << "No task found during task_table_test_and_update for " - << "task with ID " << callback_data->id; - return; - } - /* Determine whether the update happened. */ - auto message = flatbuffers::GetRoot(reply->str); - bool updated = message->updated(); - - /* Call the done callback if there is one. */ - task_table_test_and_update_callback done_callback = - (task_table_test_and_update_callback) callback_data->done_callback; - if (done_callback != NULL) { - done_callback(task, callback_data->user_context, updated); - } - /* Free the task if it is not NULL. */ - if (task != NULL) { - Task_free(task); - } - /* Clean up timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_task_table_test_and_update(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - TaskID task_id = callback_data->id; - redisAsyncContext *context = get_redis_context(db, task_id); - TaskTableTestAndUpdateData *update_data = - (TaskTableTestAndUpdateData *) callback_data->data->Get(); - - int status; - /* If the test local scheduler ID is NIL, then ignore it. 
*/ - if (update_data->test_local_scheduler_id.is_nil()) { - status = redisAsyncCommand( - context, redis_task_table_test_and_update_callback, - (void *) callback_data->timer_id, - "RAY.TASK_TABLE_TEST_AND_UPDATE %b %d %d %b", task_id.data(), - sizeof(task_id), update_data->test_state_bitmask, - update_data->update_state, update_data->local_scheduler_id.data(), - sizeof(update_data->local_scheduler_id)); - } else { - status = redisAsyncCommand( - context, redis_task_table_test_and_update_callback, - (void *) callback_data->timer_id, - "RAY.TASK_TABLE_TEST_AND_UPDATE %b %d %d %b %b", task_id.data(), - sizeof(task_id), update_data->test_state_bitmask, - update_data->update_state, update_data->local_scheduler_id.data(), - sizeof(update_data->local_scheduler_id), - update_data->test_local_scheduler_id.data(), - sizeof(update_data->test_local_scheduler_id)); - } - - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_task_table_test_and_update"); - } -} - -void redis_task_table_subscribe_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - /* The number of elements is 3 for a reply to SUBSCRIBE, and 4 for a reply to - * PSUBSCRIBE. */ - RAY_CHECK(reply->elements == 3 || reply->elements == 4) - << "reply->elements is " << reply->elements; - /* The first element is the message type and the last entry is the payload. - * The middle one or middle two elements describe the channel that was - * published on. */ - redisReply *message_type = reply->element[0]; - redisReply *payload = reply->element[reply->elements - 1]; - if (strcmp(message_type->str, "message") == 0 || - strcmp(message_type->str, "pmessage") == 0) { - /* Handle a task table event. Parse the payload and call the callback. */ - auto message = flatbuffers::GetRoot(payload->str); - /* Extract the scheduling state. */ - TaskStatus state = static_cast(message->state()); - /* Extract the local scheduler ID. */ - DBClientID local_scheduler_id = - from_flatbuf(*message->local_scheduler_id()); - /* Extract the execution dependencies. */ - auto execution_dependencies = - flatbuffers::GetRoot( - message->execution_dependencies()->data()); - /* Extract the task spec. */ - TaskSpec *spec = (TaskSpec *) message->task_spec()->data(); - int64_t task_spec_size = message->task_spec()->size(); - /* Extract the spillback information. */ - int spillback_count = message->spillback_count(); - /* Create a task. */ - /* Allocate the task execution spec on the stack and use it to construct - * the task. - */ - TaskExecutionSpec execution_spec( - from_flatbuf(*execution_dependencies->execution_dependencies()), spec, - task_spec_size, spillback_count); - Task *task = Task_alloc(execution_spec, state, local_scheduler_id); - - /* Call the subscribe callback if there is one. */ - TaskTableSubscribeData *data = - (TaskTableSubscribeData *) callback_data->data->Get(); - if (data->subscribe_callback != NULL) { - data->subscribe_callback(task, data->subscribe_context); - } - Task_free(task); - } else if (strcmp(message_type->str, "subscribe") == 0 || - strcmp(message_type->str, "psubscribe") == 0) { - /* If this condition is true, we got the initial message that acknowledged - * the subscription. 
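/* The detail handled by redis_task_table_subscribe_callback above: SUBSCRIBE
 * deliveries have three elements while PSUBSCRIBE deliveries have four (the
 * extra one is the matched pattern), so the payload is always taken from the
 * last element rather than a fixed index. */
extern "C" {
#include "hiredis/hiredis.h"
}

redisReply *example_subscription_payload(redisReply *reply) {
  if (reply == NULL || reply->type != REDIS_REPLY_ARRAY ||
      reply->elements < 3) {
    return NULL;
  }
  /* element[0] is "message"/"pmessage" or the subscribe acknowledgment; the
   * payload, if any, is always the final element. */
  return reply->element[reply->elements - 1];
}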
*/ - if (callback_data->done_callback != NULL) { - task_table_done_callback done_callback = - (task_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, callback_data->user_context); - } - /* Note that we do not destroy the callback data yet because the - * subscription callback needs this data. */ - remove_timer_callback(db->loop, callback_data); - } else { - RAY_LOG(FATAL) << "Unexpected reply type from task table subscribe. " - << "Message type is " << message_type->str; - } -} - -void redis_task_table_subscribe(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - TaskTableSubscribeData *data = - (TaskTableSubscribeData *) callback_data->data->Get(); - /* TASK_CHANNEL_PREFIX is defined in ray_redis_module.cc and must be kept in - * sync with that file. */ - const char *TASK_CHANNEL_PREFIX = "TT:"; - /* In the new code path, subscriptions currently go through the - * primary redis shard. */ - for (auto subscribe_context : db->subscribe_contexts) { - int status; - if (data->local_scheduler_id.is_nil()) { - /* TODO(swang): Implement the state_filter by translating the bitmask into - * a Redis key-matching pattern. */ - status = redisAsyncCommand( - subscribe_context, redis_task_table_subscribe_callback, - (void *) callback_data->timer_id, "PSUBSCRIBE %s*:%d", - TASK_CHANNEL_PREFIX, data->state_filter); - } else { - DBClientID local_scheduler_id = data->local_scheduler_id; - status = redisAsyncCommand( - subscribe_context, redis_task_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE %s%b:%d", - TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.data(), - sizeof(local_scheduler_id), data->state_filter); - } - if ((status == REDIS_ERR) || subscribe_context->err) { - LOG_REDIS_DEBUG(subscribe_context, "error in redis_task_table_subscribe"); - } - } -} - -/* - * ==== db client table callbacks ==== - */ - -void redis_db_client_table_remove_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - - /* Call the done callback if there is one. */ - db_client_table_done_callback done_callback = - (db_client_table_done_callback) callback_data->done_callback; - if (done_callback) { - done_callback(callback_data->id, callback_data->user_context); - } - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_db_client_table_remove(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - int status = - redisAsyncCommand(db->context, redis_db_client_table_remove_callback, - (void *) callback_data->timer_id, "RAY.DISCONNECT %b", - callback_data->id.data(), sizeof(callback_data->id)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in db_client_table_remove"); - } -} - -void redis_db_client_table_scan(DBHandle *db, - std::vector &db_clients) { - /* TODO(swang): Integrate this functionality with the Ray Redis module. To do - * this, we need the KEYS or SCAN command in Redis modules. */ - /* Get all the database client keys. */ - redisReply *reply = (redisReply *) redisCommand(db->sync_context, "KEYS %s*", - DB_CLIENT_PREFIX); - if (reply->type == REDIS_REPLY_NIL) { - return; - } - /* Get all the database client information. 
*/ - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - for (size_t i = 0; i < reply->elements; ++i) { - /* Strip the database client table prefix. */ - unsigned char *key = (unsigned char *) reply->element[i]->str; - key += strlen(DB_CLIENT_PREFIX); - size_t key_len = reply->element[i]->len; - key_len -= strlen(DB_CLIENT_PREFIX); - /* Get the database client's information. */ - DBClient db_client = redis_db_client_table_get(db, key, key_len); - db_clients.push_back(db_client); - } - freeReplyObject(reply); -} - -void redis_db_client_table_subscribe_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements > 2); - /* First entry is message type, then possibly the regex we psubscribed to, - * then topic, then payload. */ - redisReply *payload = reply->element[reply->elements - 1]; - /* If this condition is true, we got the initial message that acknowledged the - * subscription. */ - if (payload->str == NULL) { - if (callback_data->done_callback) { - db_client_table_done_callback done_callback = - (db_client_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, callback_data->user_context); - } - /* Note that we do not destroy the callback data yet because the - * subscription callback needs this data. */ - remove_timer_callback(db->loop, callback_data); - - /* Get the current db client table entries, in case we missed notifications - * before the initial subscription. This must be done before we process any - * notifications from the subscription channel, so that we don't readd an - * entry that has already been deleted. */ - std::vector db_clients; - redis_db_client_table_scan(db, db_clients); - /* Call the subscription callback for all entries that we missed. */ - DBClientTableSubscribeData *data = - (DBClientTableSubscribeData *) callback_data->data->Get(); - for (auto db_client : db_clients) { - data->subscribe_callback(&db_client, data->subscribe_context); - } - return; - } - /* Otherwise, parse the payload and call the callback. */ - auto message = - flatbuffers::GetRoot(payload->str); - - /* Parse the client type and auxiliary address from the response. If there is - * only client type, then the update was a delete. */ - DBClient db_client; - db_client.id = from_flatbuf(*message->db_client_id()); - db_client.client_type = std::string(message->client_type()->data()); - db_client.manager_address = std::string(message->manager_address()->data()); - db_client.is_alive = message->is_insertion(); - - /* Call the subscription callback. 
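/* The ordering enforced by redis_db_client_table_subscribe_callback above, in
 * outline: on the initial acknowledgment, scan the existing db-client entries
 * and deliver them through the same callback before any live notifications
 * are processed, so clients added before the subscription are not missed.
 * This sketch reuses the internal helpers defined above and is not
 * standalone. */
#include <vector>

void example_on_initial_ack(DBHandle *db, DBClientTableSubscribeData *data) {
  std::vector<DBClient> existing;
  redis_db_client_table_scan(db, existing);  /* KEYS-based snapshot */
  for (auto &client : existing) {
    data->subscribe_callback(&client, data->subscribe_context);
  }
  /* Subsequent "message" deliveries are then handled incrementally. */
}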
*/ - DBClientTableSubscribeData *data = - (DBClientTableSubscribeData *) callback_data->data->Get(); - if (data->subscribe_callback) { - data->subscribe_callback(&db_client, data->subscribe_context); - } -} - -void redis_db_client_table_subscribe(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - int status = redisAsyncCommand( - db->subscribe_context, redis_db_client_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE db_clients"); - if ((status == REDIS_ERR) || db->subscribe_context->err) { - LOG_REDIS_DEBUG(db->subscribe_context, - "error in db_client_table_register_callback"); - } -} - -void redis_local_scheduler_table_subscribe_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements == 3); - redisReply *message_type = reply->element[0]; - RAY_LOG(DEBUG) << "Local scheduler table subscribe callback, message " - << message_type->str; - - if (strcmp(message_type->str, "message") == 0) { - /* Handle a local scheduler heartbeat. Parse the payload and call the - * subscribe callback. */ - auto message = - flatbuffers::GetRoot(reply->element[2]->str); - - /* Extract the client ID. */ - DBClientID client_id = from_flatbuf(*message->db_client_id()); - /* Extract the fields of the local scheduler info struct. */ - LocalSchedulerInfo info; - if (message->is_dead()) { - /* If the local scheduler is dead, then ignore all other fields in the - * message. */ - info.is_dead = true; - } else { - /* If the local scheduler is alive, collect load information. */ - info.is_dead = false; - info.total_num_workers = message->total_num_workers(); - info.task_queue_length = message->task_queue_length(); - info.available_workers = message->available_workers(); - - info.static_resources = map_from_flatbuf(*message->static_resources()); - info.dynamic_resources = map_from_flatbuf(*message->dynamic_resources()); - } - - /* Call the subscribe callback. */ - LocalSchedulerTableSubscribeData *data = - (LocalSchedulerTableSubscribeData *) callback_data->data->Get(); - if (data->subscribe_callback) { - data->subscribe_callback(client_id, info, data->subscribe_context); - } - } else if (strcmp(message_type->str, "subscribe") == 0) { - /* The reply for the initial SUBSCRIBE command. */ - RAY_CHECK(callback_data->done_callback == NULL); - /* If the initial SUBSCRIBE was successful, clean up the timer, but don't - * destroy the callback data. 
*/ - remove_timer_callback(db->loop, callback_data); - - } else { - RAY_LOG(FATAL) << "Unexpected reply type from local scheduler subscribe."; - } -} - -void redis_local_scheduler_table_subscribe(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - int status = redisAsyncCommand( - db->subscribe_context, redis_local_scheduler_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE local_schedulers"); - if ((status == REDIS_ERR) || db->subscribe_context->err) { - LOG_REDIS_DEBUG(db->subscribe_context, - "error in redis_local_scheduler_table_subscribe"); - } -} - -void redis_local_scheduler_table_send_info_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_INTEGER); - RAY_LOG(DEBUG) << reply->integer << " subscribers received this publish."; - - RAY_CHECK(callback_data->done_callback == NULL); - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_local_scheduler_table_send_info(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - LocalSchedulerTableSendInfoData *data = - (LocalSchedulerTableSendInfoData *) callback_data->data->Get(); - - int64_t size = data->size; - uint8_t *flatbuffer_data = data->flatbuffer_data; - - int status = redisAsyncCommand( - db->context, redis_local_scheduler_table_send_info_callback, - (void *) callback_data->timer_id, "PUBLISH local_schedulers %b", - flatbuffer_data, size); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, - "error in redis_local_scheduler_table_send_info"); - } -} - -void redis_local_scheduler_table_disconnect(DBHandle *db) { - flatbuffers::FlatBufferBuilder fbb; - /* Create the flatbuffers message. */ - std::unordered_map empty_resource_map; - /* Most of the flatbuffer message fields don't matter here. Only the - * db_client_id and the is_dead field matter. */ - auto message = CreateLocalSchedulerInfoMessage( - fbb, to_flatbuf(fbb, db->client), 0, 0, 0, - map_to_flatbuf(fbb, empty_resource_map), - map_to_flatbuf(fbb, empty_resource_map), true); - fbb.Finish(message); - - redisReply *reply = (redisReply *) redisCommand( - db->sync_context, "PUBLISH local_schedulers %b", fbb.GetBufferPointer(), - (size_t) fbb.GetSize()); - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(reply->type == REDIS_REPLY_INTEGER); - RAY_LOG(DEBUG) << reply->integer << " subscribers received this publish."; - freeReplyObject(reply); -} - -void redis_driver_table_subscribe_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements == 3); - redisReply *message_type = reply->element[0]; - RAY_LOG(DEBUG) << "Driver table subscribe callback, message " - << message_type->str; - - if (strcmp(message_type->str, "message") == 0) { - /* Handle a driver heartbeat. Parse the payload and call the subscribe - * callback. */ - auto message = - flatbuffers::GetRoot(reply->element[2]->str); - /* Extract the client ID. */ - WorkerID driver_id = from_flatbuf(*message->driver_id()); - - /* Call the subscribe callback. 
*/ - DriverTableSubscribeData *data = - (DriverTableSubscribeData *) callback_data->data->Get(); - if (data->subscribe_callback) { - data->subscribe_callback(driver_id, data->subscribe_context); - } - } else if (strcmp(message_type->str, "subscribe") == 0) { - /* The reply for the initial SUBSCRIBE command. */ - RAY_CHECK(callback_data->done_callback == NULL); - /* If the initial SUBSCRIBE was successful, clean up the timer, but don't - * destroy the callback data. */ - remove_timer_callback(db->loop, callback_data); - - } else { - RAY_LOG(FATAL) << "Unexpected reply type from driver subscribe."; - } -} - -void redis_driver_table_subscribe(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - int status = redisAsyncCommand( - db->subscribe_context, redis_driver_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE driver_deaths"); - if ((status == REDIS_ERR) || db->subscribe_context->err) { - LOG_REDIS_DEBUG(db->subscribe_context, - "error in redis_driver_table_subscribe"); - } -} - -void redis_driver_table_send_driver_death_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_INTEGER); - RAY_LOG(DEBUG) << reply->integer << " subscribers received this publish."; - /* At the very least, the local scheduler that publishes this message should - * also receive it. */ - RAY_CHECK(reply->integer >= 1); - - RAY_CHECK(callback_data->done_callback == NULL); - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_driver_table_send_driver_death(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - WorkerID driver_id = callback_data->id; - - /* Create a flatbuffer object to serialize and publish. */ - flatbuffers::FlatBufferBuilder fbb; - /* Create the flatbuffers message. */ - auto message = CreateDriverTableMessage(fbb, to_flatbuf(fbb, driver_id)); - fbb.Finish(message); - - int status = redisAsyncCommand( - db->context, redis_driver_table_send_driver_death_callback, - (void *) callback_data->timer_id, "PUBLISH driver_deaths %b", - fbb.GetBufferPointer(), (size_t) fbb.GetSize()); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, - "error in redis_driver_table_send_driver_death"); - } -} - -void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - /* NOTE(swang): We purposefully do not provide a callback, leaving the table - * operation and timer active. This allows us to send a new heartbeat every - * heartbeat_timeout_milliseconds without having to allocate and deallocate - * memory for callback data each time. */ - int status = redisAsyncCommand( - db->context, NULL, (void *) callback_data->timer_id, - "PUBLISH plasma_managers %b", db->client.data(), sizeof(db->client)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, - "error in redis_plasma_manager_send_heartbeat"); - } - /* Clean up the timer and callback. 
*/ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_publish_actor_creation_notification_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_INTEGER); - RAY_LOG(DEBUG) << reply->integer << " subscribers received this publish."; - // At the very least, the local scheduler that publishes this message should - // also receive it. - RAY_CHECK(reply->integer >= 1); - - RAY_CHECK(callback_data->done_callback == NULL); - // Clean up the timer and callback. - destroy_timer_callback(db->loop, callback_data); -} - -void redis_publish_actor_creation_notification( - TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - - ActorCreationNotificationData *data = - (ActorCreationNotificationData *) callback_data->data->Get(); - - int status = redisAsyncCommand( - db->context, redis_publish_actor_creation_notification_callback, - (void *) callback_data->timer_id, "PUBLISH actor_notifications %b", - &data->flatbuffer_data[0], data->size); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, - "error in redis_publish_actor_creation_notification"); - } -} - -void redis_actor_notification_table_subscribe_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements == 3); - redisReply *message_type = reply->element[0]; - RAY_LOG(DEBUG) << "Local scheduler table subscribe callback, message " - << message_type->str; - - if (strcmp(message_type->str, "message") == 0) { - // Handle an actor notification message. Parse the payload and call the - // subscribe callback. - redisReply *payload = reply->element[2]; - ActorNotificationTableSubscribeData *data = - (ActorNotificationTableSubscribeData *) callback_data->data->Get(); - - auto message = - flatbuffers::GetRoot(payload->str); - ActorID actor_id = from_flatbuf(*message->actor_id()); - WorkerID driver_id = from_flatbuf(*message->driver_id()); - DBClientID local_scheduler_id = - from_flatbuf(*message->local_scheduler_id()); - - if (data->subscribe_callback) { - data->subscribe_callback(actor_id, driver_id, local_scheduler_id, - data->subscribe_context); - } - } else if (strcmp(message_type->str, "subscribe") == 0) { - /* The reply for the initial SUBSCRIBE command. */ - RAY_CHECK(callback_data->done_callback == NULL); - /* If the initial SUBSCRIBE was successful, clean up the timer, but don't - * destroy the callback data. 
*/ - remove_timer_callback(db->loop, callback_data); - - } else { - RAY_LOG(FATAL) << "Unexpected reply type from actor notification " - << "subscribe."; - } -} - -void redis_actor_notification_table_subscribe( - TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - int status = redisAsyncCommand( - db->subscribe_context, redis_actor_notification_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE actor_notifications"); - if ((status == REDIS_ERR) || db->subscribe_context->err) { - LOG_REDIS_DEBUG(db->subscribe_context, - "error in redis_actor_notification_table_subscribe"); - } -} - -void redis_actor_table_mark_removed(DBHandle *db, ActorID actor_id) { - int status = - redisAsyncCommand(db->context, NULL, NULL, "HSET Actor:%b removed \"1\"", - actor_id.data(), sizeof(actor_id)); - if ((status == REDIS_ERR) || db->subscribe_context->err) { - LOG_REDIS_DEBUG(db->context, "error in redis_actor_table_mark_removed"); - } -} - -void redis_push_error_rpush_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - /* The reply should be the length of the errors list after our RPUSH. */ - RAY_CHECK(reply->type == REDIS_REPLY_INTEGER); - destroy_timer_callback(db->loop, callback_data); -} - -void redis_push_error_hmset_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - - /* Make sure we were able to add the error information. */ - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - - /* Add the error to this driver's list of errors. */ - ErrorInfo *info = (ErrorInfo *) callback_data->data->Get(); - int status = redisAsyncCommand( - db->context, redis_push_error_rpush_callback, - (void *) callback_data->timer_id, "RPUSH ErrorKeys Error:%b:%b", - info->driver_id.data(), sizeof(info->driver_id), info->error_key.data(), - sizeof(info->error_key)); - if ((status == REDIS_ERR) || db->subscribe_context->err) { - LOG_REDIS_DEBUG(db->subscribe_context, "error in redis_push_error rpush"); - } -} - -void redis_push_error(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - ErrorInfo *info = (ErrorInfo *) callback_data->data->Get(); - RAY_CHECK(info->error_type < ErrorIndex::MAX && - info->error_type >= ErrorIndex::OBJECT_HASH_MISMATCH); - /// Look up the error type. - const char *error_type = error_types[static_cast(info->error_type)]; - - /* Set the error information. 
*/ - int status = redisAsyncCommand( - db->context, redis_push_error_hmset_callback, - (void *) callback_data->timer_id, - "HMSET Error:%b:%b type %s message %b data %b", info->driver_id.data(), - sizeof(info->driver_id), info->error_key.data(), sizeof(info->error_key), - error_type, info->error_message, info->size, "None", strlen("None")); - if ((status == REDIS_ERR) || db->subscribe_context->err) { - LOG_REDIS_DEBUG(db->subscribe_context, "error in redis_push_error hmset"); - } -} - -DBClientID get_db_client_id(DBHandle *db) { - RAY_CHECK(db != NULL); - return db->client; -} diff --git a/src/common/state/redis.h b/src/common/state/redis.h deleted file mode 100644 index 164069740d3e1..0000000000000 --- a/src/common/state/redis.h +++ /dev/null @@ -1,356 +0,0 @@ -#ifndef REDIS_H -#define REDIS_H - -#include - -#include "db.h" -#include "db_client_table.h" -#include "object_table.h" -#include "task_table.h" - -#include "hiredis/hiredis.h" -#include "hiredis/async.h" - -#define LOG_REDIS_ERROR(context, M, ...) \ - RAY_LOG(ERROR) << "Redis error " << context->err << " " << context->errstr \ - << "; " << M - -#define LOG_REDIS_DEBUG(context, M, ...) \ - RAY_LOG(DEBUG) << "Redis error " << context->err << " " << context->errstr \ - << "; " << M; - -struct DBHandle { - /** String that identifies this client type. */ - char *client_type; - /** Unique ID for this client. */ - DBClientID client; - /** Primary redis context for all non-subscribe connections. This is used for - * the database client table, heartbeats, and errors that should be pushed to - * the driver. */ - redisAsyncContext *context; - /** Primary redis context for "subscribe" communication. A separate context - * is needed for this communication (see - * https://github.com/redis/hiredis/issues/55). This is used for the - * database client table, heartbeats, and errors that should be pushed to - * the driver. */ - redisAsyncContext *subscribe_context; - /** Redis contexts for shards for all non-subscribe connections. All requests - * to the object table, task table, and event table should be directed here. - * The correct shard can be retrieved using get_redis_context below. */ - std::vector contexts; - /** Redis contexts for shards for "subscribe" communication. All requests - * to the object table, task table, and event table should be directed here. - * The correct shard can be retrieved using get_redis_context below. */ - std::vector subscribe_contexts; - /** The event loop this global state store connection is part of. */ - event_loop *loop; - /** Index of the database connection in the event loop */ - int64_t db_index; - /** Cache for the IP addresses of db clients. This is an unordered map mapping - * client IDs to addresses. */ - std::unordered_map db_client_cache; - /** Redis context for synchronous connections. This should only be used very - * rarely, it is not asynchronous. */ - redisContext *sync_context; -}; - -/** - * Get the Redis asynchronous context responsible for non-subscription - * communication for the given UniqueID. - * - * @param db The database handle. - * @param id The ID whose location we are querying for. - * @return The redisAsyncContext responsible for the given ID. - */ -redisAsyncContext *get_redis_context(DBHandle *db, UniqueID id); - -/** - * Get the Redis asynchronous context responsible for subscription - * communication for the given UniqueID. - * - * @param db The database handle. - * @param id The ID whose location we are querying for. 
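// ---------------------------------------------------------------------------
// Illustrative sketch, not the deleted implementation: the redis_push_error
// chain above records an error in two steps -- HMSET the Error:<driver>:<key>
// hash, then, from the HMSET callback, RPUSH the key onto the driver's
// ErrorKeys list so drivers can discover it. The same two-step shape with
// plain strings standing in for the binary IDs; *_example names are placeholders.
#include <hiredis/async.h>

#include <cassert>

struct ErrorRecordExample {
  const char *driver_id;  // stands in for the binary DriverID
  const char *error_key;  // stands in for the binary error key
  const char *type;
  const char *message;
};

static void rpush_done_example(redisAsyncContext *c, void *r, void *privdata) {
  (void)c;
  (void)privdata;
  redisReply *reply = static_cast<redisReply *>(r);
  assert(reply != nullptr && reply->type == REDIS_REPLY_INTEGER);  // new list length
}

static void hmset_done_example(redisAsyncContext *c, void *r, void *privdata) {
  redisReply *reply = static_cast<redisReply *>(r);
  assert(reply != nullptr && reply->type != REDIS_REPLY_ERROR);
  auto *info = static_cast<ErrorRecordExample *>(privdata);
  // Step 2: index the error under the driver's list of error keys.
  int status = redisAsyncCommand(c, rpush_done_example, nullptr,
                                 "RPUSH ErrorKeys Error:%s:%s",
                                 info->driver_id, info->error_key);
  assert(status != REDIS_ERR);
}

void push_error_example(redisAsyncContext *context, ErrorRecordExample *info) {
  // Step 1: store the error fields in a hash keyed by driver ID and error key.
  int status = redisAsyncCommand(context, hmset_done_example, info,
                                 "HMSET Error:%s:%s type %s message %s",
                                 info->driver_id, info->error_key,
                                 info->type, info->message);
  assert(status != REDIS_ERR);
}
// ---------------------------------------------------------------------------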
- * @return The redisAsyncContext responsible for the given ID. - */ -redisAsyncContext *get_redis_subscribe_context(DBHandle *db, UniqueID id); - -/** - * Get a list of Redis shard IP addresses from the primary shard. - * - * @param context A Redis context connected to the primary shard. - * @param db_shards_addresses The IP addresses for the shards registered - * with the primary shard will be added to this vector. - * @param db_shards_ports The IP ports for the shards registered with the - * primary shard will be added to this vector, in the same order as - * db_shards_addresses. - */ -void get_redis_shards(redisContext *context, - std::vector &db_shards_addresses, - std::vector &db_shards_ports); - -void redis_cache_set_db_client(DBHandle *db, DBClient client); - -DBClient redis_cache_get_db_client(DBHandle *db, DBClientID db_client_id); - -void redis_object_table_get_entry(redisAsyncContext *c, - void *r, - void *privdata); - -void object_table_lookup_callback(redisAsyncContext *c, - void *r, - void *privdata); - -/* - * ==== Redis object table functions ==== - */ - -/** - * Lookup object table entry in redis. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_object_table_lookup(TableCallbackData *callback_data); - -/** - * Add a location entry to the object table in redis. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_object_table_add(TableCallbackData *callback_data); - -/** - * Remove a location entry from the object table in redis. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_object_table_remove(TableCallbackData *callback_data); - -/** - * Create a client-specific channel for receiving notifications from the object - * table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_object_table_subscribe_to_notifications( - TableCallbackData *callback_data); - -/** - * Request notifications about when certain objects become available. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_object_table_request_notifications(TableCallbackData *callback_data); - -/** - * Add a new object to the object table in redis. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_result_table_add(TableCallbackData *callback_data); - -/** - * Lookup the task that created the object in redis. The result is the task ID. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_result_table_lookup(TableCallbackData *callback_data); - -/** - * Callback invoked when the reply from the object table lookup command is - * received. - * - * @param c Redis context. - * @param r Reply. - * @param privdata Data associated to the callback. - * @return Void. - */ -void redis_object_table_lookup_callback(redisAsyncContext *c, - void *r, - void *privdata); - -/* - * ==== Redis task table function ===== - */ - -/** - * Get a task table entry, including the task spec and the task's scheduling - * information. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. 
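// ---------------------------------------------------------------------------
// Illustrative sketch, not the deleted implementation: get_redis_context() and
// get_redis_subscribe_context(), declared above, map a UniqueID onto one of the
// per-shard hiredis contexts held in DBHandle. One plausible way to do that
// mapping -- hash a fixed-size prefix of the ID and take it modulo the shard
// count; the deleted implementation may have hashed differently.
#include <cstdint>
#include <cstring>
#include <vector>

struct redisAsyncContext;  // opaque here; defined in hiredis/async.h

template <typename IDType>
redisAsyncContext *pick_shard_context_example(
    const std::vector<redisAsyncContext *> &shard_contexts, const IDType &id) {
  uint64_t prefix = 0;
  // IDs are fixed-size byte strings (e.g. 20 bytes), so the first 8 bytes are
  // enough to spread keys across shards.
  std::memcpy(&prefix, id.data(), sizeof(prefix));
  return shard_contexts[prefix % shard_contexts.size()];
}
// ---------------------------------------------------------------------------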
- */ -void redis_task_table_get_task(TableCallbackData *callback_data); - -/** - * Add a task table entry with a new task spec and the task's scheduling - * information. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_task_table_add_task(TableCallbackData *callback_data); - -/** - * Update a task table entry with the task's scheduling information. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_task_table_update(TableCallbackData *callback_data); - -/** - * Update a task table entry with the task's scheduling information, if the - * task's current scheduling information matches the test value. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_task_table_test_and_update(TableCallbackData *callback_data); - -/** - * Callback invoked when the reply from the task push command is received. - * - * @param c Redis context. - * @param r Reply (not used). - * @param privdata Data associated to the callback. - * @return Void. - */ -void redis_task_table_publish_push_callback(redisAsyncContext *c, - void *r, - void *privdata); - -/** - * Callback invoked when the reply from the task publish command is received. - * - * @param c Redis context. - * @param r Reply (not used). - * @param privdata Data associated to the callback. - * @return Void. - */ -void redis_task_table_publish_publish_callback(redisAsyncContext *c, - void *r, - void *privdata); - -/** - * Subscribe to updates of the task table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_task_table_subscribe(TableCallbackData *callback_data); - -/** - * Remove a client from the db clients table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_db_client_table_remove(TableCallbackData *callback_data); - -/** - * Subscribe to updates from the db client table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_db_client_table_subscribe(TableCallbackData *callback_data); - -/** - * Subscribe to updates from the local scheduler table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_local_scheduler_table_subscribe(TableCallbackData *callback_data); - -/** - * Publish an update to the local scheduler table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_local_scheduler_table_send_info(TableCallbackData *callback_data); - -/** - * Synchronously publish a null update to the local scheduler table signifying - * that we are about to exit. - * - * @param db The database handle of the dying local scheduler. - * @return Void. - */ -void redis_local_scheduler_table_disconnect(DBHandle *db); - -/** - * Subscribe to updates from the driver table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_driver_table_subscribe(TableCallbackData *callback_data); - -/** - * Publish an update to the driver table. 
- * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_driver_table_send_driver_death(TableCallbackData *callback_data); - -void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data); - -/** - * Marks an actor as removed. This prevents the actor from being resurrected. - * - * @param db The database handle. - * @param actor_id The actor id to mark as removed. - * @return Void. - */ -void redis_actor_table_mark_removed(DBHandle *db, ActorID actor_id); - -/// Publish an actor creation notification. -/// -/// \param callback_data Data structure containing redis connection and timeout -/// information. -/// \return Void. -void redis_publish_actor_creation_notification( - TableCallbackData *callback_data); - -/** - * Subscribe to updates about newly created actors. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_actor_notification_table_subscribe(TableCallbackData *callback_data); - -void redis_object_info_subscribe(TableCallbackData *callback_data); - -void redis_push_error(TableCallbackData *callback_data); - -#endif /* REDIS_H */ diff --git a/src/common/state/table.cc b/src/common/state/table.cc deleted file mode 100644 index 8269c2b1e7396..0000000000000 --- a/src/common/state/table.cc +++ /dev/null @@ -1,200 +0,0 @@ -#include "table.h" - -#include -#include -#include "redis.h" - -BaseCallbackData::BaseCallbackData(void *data) { - data_ = data; -} - -BaseCallbackData::~BaseCallbackData(void) {} - -void *BaseCallbackData::Get(void) { - return data_; -} - -CommonCallbackData::CommonCallbackData(void *data) : BaseCallbackData(data) {} - -CommonCallbackData::~CommonCallbackData(void) { - free(data_); -} - -TaskCallbackData::TaskCallbackData(Task *task_data) - : BaseCallbackData(task_data) {} - -TaskCallbackData::~TaskCallbackData(void) { - Task *task = (Task *) data_; - Task_free(task); -} - -/* The default behavior is to retry every ten seconds forever. */ -static const RetryInfo default_retry = {.num_retries = -1, - .timeout = 10000, - .fail_callback = NULL}; - -static int64_t callback_data_id = 0; - -TableCallbackData *init_table_callback(DBHandle *db_handle, - UniqueID id, - const char *label, - OWNER BaseCallbackData *data, - RetryInfo *retry, - table_done_callback done_callback, - table_retry_callback retry_callback, - void *user_context) { - RAY_CHECK(db_handle); - RAY_CHECK(db_handle->loop); - RAY_CHECK(data); - /* If no retry info is provided, use the default retry info. */ - if (retry == NULL) { - retry = (RetryInfo *) &default_retry; - } - RAY_CHECK(retry); - /* Allocate and initialize callback data structure for object table */ - TableCallbackData *callback_data = - (TableCallbackData *) malloc(sizeof(TableCallbackData)); - RAY_CHECK(callback_data != NULL) << "Memory allocation error!"; - callback_data->id = id; - callback_data->label = label; - callback_data->retry = *retry; - callback_data->done_callback = done_callback; - callback_data->retry_callback = retry_callback; - callback_data->data = data; - callback_data->requests_info = NULL; - callback_data->user_context = user_context; - callback_data->db_handle = db_handle; - /* TODO(ekl) set a retry timer once we've figured out the retry conditions - * and have a solution to the O(n^2) ae timers issue. For now, use a dummy - * timer id to uniquely id this callback. 
*/ - callback_data->timer_id = callback_data_id++; - outstanding_callbacks_add(callback_data); - - RAY_LOG(DEBUG) << "Initializing table command " << callback_data->label - << " with timer ID " << callback_data->timer_id; - callback_data->retry_callback(callback_data); - - return callback_data; -} - -void destroy_timer_callback(event_loop *loop, - TableCallbackData *callback_data) { - /* This is commented out because we no longer add timers to the event loop for - * each Redis command. */ - // event_loop_remove_timer(loop, callback_data->timer_id); - destroy_table_callback(callback_data); -} - -void remove_timer_callback(event_loop *loop, TableCallbackData *callback_data) { - /* This is commented out because we no longer add timers to the event loop for - * each Redis command. */ - // event_loop_remove_timer(loop, callback_data->timer_id); -} - -void destroy_table_callback(TableCallbackData *callback_data) { - RAY_CHECK(callback_data != NULL); - - if (callback_data->requests_info) - free(callback_data->requests_info); - - RAY_CHECK(callback_data->data != NULL); - delete callback_data->data; - callback_data->data = NULL; - - outstanding_callbacks_remove(callback_data); - - /* Timer is removed via EVENT_LOOP_TIMER_DONE in the timeout callback. */ - free(callback_data); -} - -int64_t table_timeout_handler(event_loop *loop, - int64_t timer_id, - void *user_context) { - RAY_CHECK(loop != NULL); - RAY_CHECK(user_context != NULL); - TableCallbackData *callback_data = (TableCallbackData *) user_context; - - RAY_CHECK(callback_data->retry.num_retries >= 0 || - callback_data->retry.num_retries == -1); - RAY_LOG(WARNING) << "retrying operation " << callback_data->label - << ", retry_count = " << callback_data->retry.num_retries; - - if (callback_data->retry.num_retries == 0) { - /* We didn't get a response from the database after exhausting all retries; - * let user know, cleanup the state, and remove the timer. */ - RAY_LOG(WARNING) << "Table command " << callback_data->label - << " with timer ID " << timer_id << " failed"; - if (callback_data->retry.fail_callback) { - callback_data->retry.fail_callback(callback_data->id, - callback_data->user_context, - callback_data->data->Get()); - } - destroy_table_callback(callback_data); - return EVENT_LOOP_TIMER_DONE; - } - - /* Decrement retry count and try again. We use -1 to indicate infinite - * retries. */ - if (callback_data->retry.num_retries != -1) { - callback_data->retry.num_retries--; - } - callback_data->retry_callback(callback_data); - return callback_data->retry.timeout; -} - -/** - * Unordered map maintaining the outstanding callbacks. - * - * This unordered map is used to handle the following case: - * - a table command is issued with an associated callback and a callback data - * structure; - * - the last timeout associated to this command expires, as a result the - * callback data structure is freed; - * - a reply arrives, but now the callback data structure is gone, so we have - * to ignore this reply; - * - * This unordered map enables us to ignore such replies. The operations on the - * unordered map are as follows. - * - * When we issue a table command and a timeout event to wait for the reply, we - * add a new entry to the unordered map that is keyed by the ID of the timer. - * Note that table commands must have unique timer IDs, which are assigned by - * the Redis ae event loop. - * - * When we receive the reply, we check whether the callback still exists in - * this unordered map, and if not we just ignore the reply. 
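// ---------------------------------------------------------------------------
// Illustrative sketch, not the deleted implementation: table_timeout_handler
// above encodes the retry policy -- num_retries == -1 retries forever, 0 gives
// up and calls fail_callback, anything else is decremented before the next
// attempt. The same decision logic in isolation; all names are placeholders.
#include <cstdint>
#include <functional>

struct RetrySketch {
  int num_retries;      // -1 means retry forever
  uint64_t timeout_ms;  // delay before the next attempt
};

// Returns 0 when the operation should be abandoned, otherwise the delay in
// milliseconds before the next retry (i.e. the value to rearm the timer with).
uint64_t next_retry_delay(RetrySketch &retry, const std::function<void()> &fail,
                          const std::function<void()> &attempt) {
  if (retry.num_retries == 0) {
    if (fail) fail();  // retries exhausted: report failure, stop the timer
    return 0;
  }
  if (retry.num_retries != -1) {
    retry.num_retries--;  // finite budget: consume one retry
  }
  attempt();                // reissue the command
  return retry.timeout_ms;  // keep the timer running
}
// ---------------------------------------------------------------------------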
If the callback does - * exist, the reply receiver is responsible for removing the timer and the - * entry associated to the callback, or else the timeout handler will continue - * firing. - * - * When the last timeout associated to the command expires we remove the entry - * associated to the callback. - */ -static std::unordered_map outstanding_callbacks; - -void outstanding_callbacks_add(TableCallbackData *callback_data) { - outstanding_callbacks[callback_data->timer_id] = callback_data; -} - -TableCallbackData *outstanding_callbacks_find(int64_t key) { - auto it = outstanding_callbacks.find(key); - if (it != outstanding_callbacks.end()) { - return it->second; - } - return NULL; -} - -void outstanding_callbacks_remove(TableCallbackData *callback_data) { - outstanding_callbacks.erase(callback_data->timer_id); -} - -void destroy_outstanding_callbacks(event_loop *loop) { - /* We have to be careful because destroy_timer_callback modifies - * outstanding_callbacks in place */ - auto it = outstanding_callbacks.begin(); - while (it != outstanding_callbacks.end()) { - auto next_it = std::next(it, 1); - destroy_timer_callback(loop, it->second); - it = next_it; - } -} diff --git a/src/common/state/table.h b/src/common/state/table.h deleted file mode 100644 index 1fadcf339cef3..0000000000000 --- a/src/common/state/table.h +++ /dev/null @@ -1,216 +0,0 @@ -#ifndef TABLE_H -#define TABLE_H - -#include "common.h" -#include "task.h" -#include "db.h" - -typedef struct TableCallbackData TableCallbackData; - -/* An abstract class for any data passed by the user into a table operation. - * This class wraps arbitrary pointers and allows the caller to define a custom - * destructor, for data that is not allocated with malloc. */ -class BaseCallbackData { - public: - BaseCallbackData(void *data); - virtual ~BaseCallbackData(void) = 0; - - /* Return the pointer to the data. */ - void *Get(void); - - protected: - /* The pointer to the data. */ - void *data_; -}; - -/* A common class for malloc'ed data passed by the user into a table operation. - * This should ONLY be used when only a free is necessary. */ -class CommonCallbackData : public BaseCallbackData { - public: - CommonCallbackData(void *data); - ~CommonCallbackData(void); -}; - -/* A class for Task data passed by the user into a table operation. This calls - * task cleanup in the destructor. */ -class TaskCallbackData : public BaseCallbackData { - public: - TaskCallbackData(Task *task_data); - ~TaskCallbackData(void); -}; - -typedef void *table_done_callback; - -/* The callback called when the database operation hasn't completed after - * the number of retries specified for the operation. - * - * @param id The unique ID that identifies this callback. Examples include an - * object ID or task ID. - * @param user_context The state context for the callback. This is equivalent - * to the user_context field in TableCallbackData. - * @param user_data A data argument for the callback. This is equivalent to the - * data field in TableCallbackData. The user is responsible for - * freeing user_data. - */ -typedef void (*table_fail_callback)(UniqueID id, - void *user_context, - void *user_data); - -typedef void (*table_retry_callback)(TableCallbackData *callback_data); - -/** - * Data structure consolidating the retry related variables. If a NULL - * RetryInfo struct is used, the default behavior will be to retry infinitely - * many times. - */ -typedef struct { - /** Number of retries. 
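// ---------------------------------------------------------------------------
// Illustrative sketch, not the deleted implementation: the outstanding-callbacks
// comment above describes the guard against late replies -- a reply handler
// first looks its timer ID up in the map and drops the reply if the entry was
// already destroyed by the timeout path. Reduced to a plain std::unordered_map;
// CallbackRecord and on_reply_example are placeholders.
#include <cstdint>
#include <unordered_map>

struct CallbackRecord {
  const char *label;
};

static std::unordered_map<int64_t, CallbackRecord *> outstanding_example;

void on_reply_example(int64_t timer_id) {
  auto it = outstanding_example.find(timer_id);
  if (it == outstanding_example.end()) {
    return;  // The command timed out and was cleaned up; ignore the stale reply.
  }
  CallbackRecord *record = it->second;
  (void)record;  // ... handle the reply here ...
  // Remove the entry so the timeout handler stops firing for this command.
  outstanding_example.erase(it);
}
// ---------------------------------------------------------------------------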
This field will be decremented every time a retry - * occurs (unless the value is -1). If this value is -1, then there will be - * infinitely many retries. */ - int num_retries; - /** Timeout, in milliseconds. */ - uint64_t timeout; - /** The callback that will be called if there are no more retries left. */ - table_fail_callback fail_callback; -} RetryInfo; - -struct TableCallbackData { - /** ID of the entry in the table that we are going to look up, remove or add. - */ - UniqueID id; - /** A label to identify the original request for logging purposes. */ - const char *label; - /** The callback that will be called when results is returned. */ - table_done_callback done_callback; - /** The callback that will be called to initiate the next try. */ - table_retry_callback retry_callback; - /** Retry information containing the remaining number of retries, the timeout - * before the next retry, and a pointer to the failure callback. - */ - RetryInfo retry; - /** Pointer to the data that is entered into the table. This can be used to - * pass the result of the call to the callback. The callback takes ownership - * over this data and will free it. */ - BaseCallbackData *data; - /** Pointer to the data used internally to handle multiple database requests. - */ - void *requests_info; - /** User context. */ - void *user_context; - /** Handle to db. */ - DBHandle *db_handle; - /** Handle to timer. */ - int64_t timer_id; -}; - -/** - * Function to handle the timeout event. - * - * @param loop Event loop. - * @param timer_id Timer identifier. - * @param context Pointer to the callback data for the object table - * @return Timeout to reset the timer if we need to try again, or - * EVENT_LOOP_TIMER_DONE if retry_count == 0. - */ -int64_t table_timeout_handler(event_loop *loop, - int64_t timer_id, - void *context); - -/** - * Initialize the table callback and call the retry_callback for the first time. - * - * @param db_handle Database handle. - * @param id ID of the object that is looked up, added or removed. - * @param label A string label to identify the type of table request for - * logging purposes. - * @param data Data entered into the table. Shall be freed by the user. Caller - * must specify a destructor by wrapping a void *pointer in a - * BaseCallbackData class. - * @param retry Retry relevant information: retry timeout, number of remaining - * retries, and retry callback. - * @param done_callback Function to be called when database returns result. - * @param fail_callback Function to be called when number of retries is - * exhausted. - * @param user_context Context that can be provided by the user and will be - * passed on to the various callbacks. - * @return New table callback data struct. - */ -TableCallbackData *init_table_callback(DBHandle *db_handle, - UniqueID id, - const char *label, - OWNER BaseCallbackData *data, - RetryInfo *retry, - table_done_callback done_callback, - table_retry_callback retry_callback, - void *user_context); - -/** - * Destroy any state associated with the callback data. This removes all - * associated state from the outstanding callbacks unordered map and frees any - * associated memory. This does not remove any associated timer events. - * - * @param callback_data The pointer to the data structure of the callback we - * want to remove. - * @return Void. - */ -void destroy_table_callback(TableCallbackData *callback_data); - -/** - * Destroy all state events associated with the callback data, including memory - * and timer events. 
- * - * @param loop The event loop. - * @param callback_data The pointer to the data structure of the callback we - * want to remove. - * @return Void. - */ -void destroy_timer_callback(event_loop *loop, TableCallbackData *callback_data); - -/** - * Remove the callback timer without destroying the callback data. - * - * @param loop The event loop. - * @param callback_data The pointer to the data structure of the callback. - * @return Void. - */ -void remove_timer_callback(event_loop *loop, TableCallbackData *callback_data); - -/** - * Add an outstanding callback entry. - * - * @param callback_data The pointer to the data structure of the callback we - * want to insert. - * @return None. - */ -void outstanding_callbacks_add(TableCallbackData *callback_data); - -/** - * Find an outstanding callback entry. - * - * @param key The key for the outstanding callbacks unordered map. We use the - * timer ID assigned by the Redis ae event loop. - * @return Returns the callback data if found, NULL otherwise. - */ -TableCallbackData *outstanding_callbacks_find(int64_t key); - -/** - * Remove an outstanding callback entry. This only removes the callback entry - * from the unordered map. It does not free the entry or remove any associated - * timer events. - * - * @param callback_data The pointer to the data structure of the callback we - * want to remove. - * @return Void. - */ -void outstanding_callbacks_remove(TableCallbackData *callback_data); - -/** - * Destroy all outstanding callbacks and remove their associated timer events - * from the event loop. - * - * @param loop The event loop from which we want to remove the timer events. - * @return Void. - */ -void destroy_outstanding_callbacks(event_loop *loop); - -#endif /* TABLE_H */ diff --git a/src/common/state/task_table.cc b/src/common/state/task_table.cc deleted file mode 100644 index 514350b08353c..0000000000000 --- a/src/common/state/task_table.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "task_table.h" -#include "redis.h" - -#define NUM_DB_REQUESTS 2 - -void task_table_get_task(DBHandle *db_handle, - TaskID task_id, - RetryInfo *retry, - task_table_get_callback get_callback, - void *user_context) { - init_table_callback( - db_handle, task_id, __func__, new CommonCallbackData(NULL), retry, - (void *) get_callback, redis_task_table_get_task, user_context); -} - -void task_table_add_task(DBHandle *db_handle, - OWNER Task *task, - RetryInfo *retry, - task_table_done_callback done_callback, - void *user_context) { - init_table_callback(db_handle, Task_task_id(task), __func__, - new TaskCallbackData(task), retry, - (table_done_callback) done_callback, - redis_task_table_add_task, user_context); -} - -void task_table_update(DBHandle *db_handle, - OWNER Task *task, - RetryInfo *retry, - task_table_done_callback done_callback, - void *user_context) { - init_table_callback(db_handle, Task_task_id(task), __func__, - new TaskCallbackData(task), retry, - (table_done_callback) done_callback, - redis_task_table_update, user_context); -} - -void task_table_test_and_update( - DBHandle *db_handle, - TaskID task_id, - DBClientID test_local_scheduler_id, - TaskStatus test_state_bitmask, - TaskStatus update_state, - RetryInfo *retry, - task_table_test_and_update_callback done_callback, - void *user_context) { - TaskTableTestAndUpdateData *update_data = - (TaskTableTestAndUpdateData *) malloc(sizeof(TaskTableTestAndUpdateData)); - update_data->test_local_scheduler_id = test_local_scheduler_id; - update_data->test_state_bitmask = test_state_bitmask; - 
update_data->update_state = update_state; - /* Update the task entry's local scheduler with this client's ID. */ - update_data->local_scheduler_id = db_handle->client; - init_table_callback(db_handle, task_id, __func__, - new CommonCallbackData(update_data), retry, - (table_done_callback) done_callback, - redis_task_table_test_and_update, user_context); -} - -/* TODO(swang): A corresponding task_table_unsubscribe. */ -void task_table_subscribe(DBHandle *db_handle, - DBClientID local_scheduler_id, - TaskStatus state_filter, - task_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry, - task_table_done_callback done_callback, - void *user_context) { - TaskTableSubscribeData *sub_data = - (TaskTableSubscribeData *) malloc(sizeof(TaskTableSubscribeData)); - sub_data->local_scheduler_id = local_scheduler_id; - sub_data->state_filter = state_filter; - sub_data->subscribe_callback = subscribe_callback; - sub_data->subscribe_context = subscribe_context; - - init_table_callback(db_handle, local_scheduler_id, __func__, - new CommonCallbackData(sub_data), retry, - (table_done_callback) done_callback, - redis_task_table_subscribe, user_context); -} diff --git a/src/common/state/task_table.h b/src/common/state/task_table.h deleted file mode 100644 index 3884ddece8931..0000000000000 --- a/src/common/state/task_table.h +++ /dev/null @@ -1,190 +0,0 @@ -#ifndef task_table_H -#define task_table_H - -#include "db.h" -#include "table.h" -#include "task.h" - -/** - * The task table is a message bus that is used for communication between local - * and global schedulers (and also persisted to the state database). Here are - * examples of events that are recorded by the task table: - * - * 1) Local schedulers write to it when submitting a task to the global - * scheduler. - * 2) The global scheduler subscribes to updates to the task table to get tasks - * submitted by local schedulers. - * 3) The global scheduler writes to it when assigning a task to a local - * scheduler. - * 4) Local schedulers subscribe to updates to the task table to get tasks - * assigned to them by the global scheduler. - * 5) Local schedulers write to it when a task finishes execution. - */ - -/* Callback called when a task table write operation completes. */ -typedef void (*task_table_done_callback)(TaskID task_id, void *user_context); - -/* Callback called when a task table read operation completes. If the task ID - * was not in the task table, then the task pointer will be NULL. */ -typedef void (*task_table_get_callback)(Task *task, void *user_context); - -/* Callback called when a task table test-and-update operation completes. If - * the task ID was not in the task table, then the task pointer will be NULL. - * If the update succeeded, the updated field will be set to true. */ -typedef void (*task_table_test_and_update_callback)(Task *task, - void *user_context, - bool updated); - -/** - * Get a task's entry from the task table. - * - * @param db_handle Database handle. - * @param task_id The ID of the task we want to look up. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. 
- */ -void task_table_get_task(DBHandle *db, - TaskID task_id, - RetryInfo *retry, - task_table_get_callback get_callback, - void *user_context); - -/** - * Add a task entry, including task spec and scheduling information, to the task - * table. This will overwrite any task already in the task table with the same - * task ID. - * - * @param db_handle Database handle. - * @param task The task entry to add to the table. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. - */ -void task_table_add_task(DBHandle *db_handle, - OWNER Task *task, - RetryInfo *retry, - task_table_done_callback done_callback, - void *user_context); - -/* - * ==== Publish the task table ==== - */ - -/** - * Update a task's scheduling information in the task table. This assumes that - * the task spec already exists in the task table entry. - * - * @param db_handle Database handle. - * @param task The task entry to add to the table. The task spec in the entry is - * ignored. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. - */ -void task_table_update(DBHandle *db_handle, - OWNER Task *task, - RetryInfo *retry, - task_table_done_callback done_callback, - void *user_context); - -/** - * Update a task's scheduling information in the task table, if the current - * value matches the given test value. If the update succeeds, it also updates - * the task entry's local scheduler ID with the ID of the client who called - * this function. This assumes that the task spec already exists in the task - * table entry. - * - * @param db_handle Database handle. - * @param task_id The task ID of the task entry to update. - * @param test_local_scheduler_id The local scheduler ID to test the current - * local scheduler ID against. If not NIL_ID, and if the current local - * scheduler ID does not match it, then the update will not happen. - * @param test_state_bitmask The bitmask to apply to the task entry's current - * scheduling state. The update happens if and only if the current - * scheduling state AND-ed with the bitmask is greater than 0 and the - * local scheduler ID test passes. - * @param update_state The value to update the task entry's scheduling state - * with, if the current state matches test_state_bitmask. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. - */ -void task_table_test_and_update( - DBHandle *db_handle, - TaskID task_id, - DBClientID test_local_scheduler_id, - TaskStatus test_state_bitmask, - TaskStatus update_state, - RetryInfo *retry, - task_table_test_and_update_callback done_callback, - void *user_context); - -/* Data that is needed to test and set the task's scheduling state. */ -typedef struct { - /** The value to test the current local scheduler ID against. This field is - * ignored if equal to NIL_ID. 
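// ---------------------------------------------------------------------------
// Illustrative sketch, not the deleted implementation: task_table_test_and_update,
// documented above, is a compare-and-swap on the task's scheduling state -- the
// update is applied only if the current state AND-ed with test_state_bitmask is
// non-zero (and the local scheduler ID test passes). The core check, reduced to
// plain integers; test_and_update_example is a placeholder name.
#include <cstdint>

// Returns true and writes update_state into *state when the bitmask test passes.
bool test_and_update_example(uint32_t *state, uint32_t test_state_bitmask,
                             uint32_t update_state) {
  if ((*state & test_state_bitmask) == 0) {
    return false;  // the current scheduling state does not match the test mask
  }
  *state = update_state;
  return true;
}
// ---------------------------------------------------------------------------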
*/ - DBClientID test_local_scheduler_id; - TaskStatus test_state_bitmask; - TaskStatus update_state; - DBClientID local_scheduler_id; -} TaskTableTestAndUpdateData; - -/* - * ==== Subscribing to the task table ==== - */ - -/* Callback for subscribing to the task table. */ -typedef void (*task_table_subscribe_callback)(Task *task, void *user_context); - -/** - * Register a callback for a task event. An event is any update of a task in - * the task table, produced by task_table_add_task or task_table_add_task. - * Events include changes to the task's scheduling state or changes to the - * task's local scheduler ID. - * - * @param db_handle Database handle. - * @param subscribe_callback Callback that will be called when the task table is - * updated. - * @param subscribe_context Context that will be passed into the - * subscribe_callback. - * @param local_scheduler_id The db_client_id of the local scheduler whose - * events we want to listen to. If you want to subscribe to updates from - * all local schedulers, pass in NIL_ID. - * @param state_filter Events we want to listen to. Can have values from the - * enum "scheduling_state" in task.h. - * TODO(pcm): Make it possible to combine these using flags like - * TASK_STATUS_WAITING | TASK_STATUS_SCHEDULED. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. - */ -void task_table_subscribe(DBHandle *db_handle, - DBClientID local_scheduler_id, - TaskStatus state_filter, - task_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry, - task_table_done_callback done_callback, - void *user_context); - -/* Data that is needed to register task table subscribe callbacks with the state - * database. */ -typedef struct { - DBClientID local_scheduler_id; - TaskStatus state_filter; - task_table_subscribe_callback subscribe_callback; - void *subscribe_context; -} TaskTableSubscribeData; - -#endif /* task_table_H */ diff --git a/src/common/task.cc b/src/common/task.cc deleted file mode 100644 index 60110fe225432..0000000000000 --- a/src/common/task.cc +++ /dev/null @@ -1,606 +0,0 @@ -#include - -#include "common_protocol.h" - -#include "task.h" - -extern "C" { -#include "sha256.h" -} - -ObjectID task_compute_return_id(TaskID task_id, int64_t return_index) { - /* Here, return_indices need to be >= 0, so we can use negative - * indices for put. */ - RAY_DCHECK(return_index >= 0); - /* TODO(rkn): This line requires object and task IDs to be the same size. */ - ObjectID return_id = task_id; - int64_t *first_bytes = (int64_t *) &return_id; - /* XOR the first bytes of the object ID with the return index. We add one so - * the first return ID is not the same as the task ID. */ - *first_bytes = *first_bytes ^ (return_index + 1); - return return_id; -} - -ObjectID task_compute_put_id(TaskID task_id, int64_t put_index) { - RAY_DCHECK(put_index >= 0); - /* TODO(pcm): This line requires object and task IDs to be the same size. */ - ObjectID put_id = task_id; - int64_t *first_bytes = (int64_t *) &put_id; - /* XOR the first bytes of the object ID with the return index. We add one so - * the first return ID is not the same as the task ID. 
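// ---------------------------------------------------------------------------
// Illustrative sketch, not the deleted implementation: TaskTableSubscribeData
// above (filled in by task_table_subscribe in the .cc hunk earlier) is the
// usual C trampoline -- the caller's function pointer and context are packed
// into a heap-allocated struct that rides along with the table request until
// the subscription fires. Reduced version; the *_example names are placeholders.
#include <cstdlib>

typedef void (*subscribe_cb_example)(int event, void *user_context);

struct SubscribeDataExample {
  subscribe_cb_example callback;
  void *context;
};

SubscribeDataExample *pack_subscription_example(subscribe_cb_example cb, void *ctx) {
  auto *data =
      static_cast<SubscribeDataExample *>(std::malloc(sizeof(SubscribeDataExample)));
  data->callback = cb;
  data->context = ctx;
  return data;  // in the deleted code, CommonCallbackData frees this on cleanup
}

void fire_subscription_example(SubscribeDataExample *data, int event) {
  if (data->callback != nullptr) {
    data->callback(event, data->context);
  }
}
// ---------------------------------------------------------------------------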
*/ - *first_bytes = *first_bytes ^ (-put_index - 1); - return put_id; -} - -class TaskBuilder { - public: - void Start(UniqueID driver_id, - TaskID parent_task_id, - int64_t parent_counter, - ActorID actor_creation_id, - ObjectID actor_creation_dummy_object_id, - ActorID actor_id, - ActorHandleID actor_handle_id, - int64_t actor_counter, - bool is_actor_checkpoint_method, - FunctionID function_id, - int64_t num_returns) { - driver_id_ = driver_id; - parent_task_id_ = parent_task_id; - parent_counter_ = parent_counter; - actor_creation_id_ = actor_creation_id; - actor_creation_dummy_object_id_ = actor_creation_dummy_object_id; - actor_id_ = actor_id; - actor_handle_id_ = actor_handle_id; - actor_counter_ = actor_counter; - is_actor_checkpoint_method_ = is_actor_checkpoint_method; - function_id_ = function_id; - num_returns_ = num_returns; - - /* Compute hashes. */ - sha256_init(&ctx); - sha256_update(&ctx, (BYTE *) &driver_id, sizeof(driver_id)); - sha256_update(&ctx, (BYTE *) &parent_task_id, sizeof(parent_task_id)); - sha256_update(&ctx, (BYTE *) &parent_counter, sizeof(parent_counter)); - sha256_update(&ctx, (BYTE *) &actor_creation_id, sizeof(actor_creation_id)); - sha256_update(&ctx, (BYTE *) &actor_creation_dummy_object_id, - sizeof(actor_creation_dummy_object_id)); - sha256_update(&ctx, (BYTE *) &actor_id, sizeof(actor_id)); - sha256_update(&ctx, (BYTE *) &actor_counter, sizeof(actor_counter)); - sha256_update(&ctx, (BYTE *) &is_actor_checkpoint_method, - sizeof(is_actor_checkpoint_method)); - sha256_update(&ctx, (BYTE *) &function_id, sizeof(function_id)); - } - - void NextReferenceArgument(ObjectID object_ids[], int num_object_ids) { - args.push_back( - CreateArg(fbb, to_flatbuf(fbb, &object_ids[0], num_object_ids))); - sha256_update(&ctx, (BYTE *) &object_ids[0], - sizeof(object_ids[0]) * num_object_ids); - } - - void NextValueArgument(uint8_t *value, int64_t length) { - auto arg = fbb.CreateString((const char *) value, length); - auto empty_ids = fbb.CreateVectorOfStrings({}); - args.push_back(CreateArg(fbb, empty_ids, arg)); - sha256_update(&ctx, (BYTE *) value, length); - } - - void SetRequiredResource(const std::string &resource_name, double value) { - RAY_CHECK(resource_map_.count(resource_name) == 0); - resource_map_[resource_name] = value; - } - - uint8_t *Finish(int64_t *size) { - /* Add arguments. */ - auto arguments = fbb.CreateVector(args); - /* Update hash. */ - BYTE buff[DIGEST_SIZE]; - sha256_final(&ctx, buff); - TaskID task_id; - RAY_CHECK(sizeof(task_id) <= DIGEST_SIZE); - memcpy(&task_id, buff, sizeof(task_id)); - /* Add return object IDs. */ - std::vector> returns; - for (int64_t i = 0; i < num_returns_; i++) { - ObjectID return_id = task_compute_return_id(task_id, i); - returns.push_back(to_flatbuf(fbb, return_id)); - } - /* Create TaskInfo. */ - auto message = CreateTaskInfo( - fbb, to_flatbuf(fbb, driver_id_), to_flatbuf(fbb, task_id), - to_flatbuf(fbb, parent_task_id_), parent_counter_, - to_flatbuf(fbb, actor_creation_id_), - to_flatbuf(fbb, actor_creation_dummy_object_id_), - to_flatbuf(fbb, actor_id_), to_flatbuf(fbb, actor_handle_id_), - actor_counter_, is_actor_checkpoint_method_, - to_flatbuf(fbb, function_id_), arguments, fbb.CreateVector(returns), - map_to_flatbuf(fbb, resource_map_)); - /* Finish the TaskInfo. 
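// ---------------------------------------------------------------------------
// Illustrative sketch, not the deleted implementation: task_compute_return_id
// and task_compute_put_id above derive per-object IDs by XOR-ing the first
// eight bytes of the task ID with (return_index + 1) and (-put_index - 1), so
// return objects, put objects, and the task ID itself never collide. A toy
// check of that property on a fake 20-byte ID.
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

using FakeID = std::array<uint8_t, 20>;

static FakeID xor_prefix(FakeID id, int64_t value) {
  int64_t prefix;
  std::memcpy(&prefix, id.data(), sizeof(prefix));
  prefix ^= value;
  std::memcpy(id.data(), &prefix, sizeof(prefix));
  return id;
}

int main() {
  FakeID task_id{};                             // all zeros, for the demo only
  FakeID return0 = xor_prefix(task_id, 0 + 1);  // first return object
  FakeID put0 = xor_prefix(task_id, -0 - 1);    // first put object
  assert(return0 != task_id && put0 != task_id && return0 != put0);
  return 0;
}
// ---------------------------------------------------------------------------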
*/ - fbb.Finish(message); - *size = fbb.GetSize(); - uint8_t *result = (uint8_t *) malloc(*size); - memcpy(result, fbb.GetBufferPointer(), *size); - fbb.Clear(); - args.clear(); - resource_map_.clear(); - return result; - } - - private: - flatbuffers::FlatBufferBuilder fbb; - std::vector> args; - SHA256_CTX ctx; - - /* Data for the builder. */ - UniqueID driver_id_; - TaskID parent_task_id_; - int64_t parent_counter_; - ActorID actor_creation_id_; - ObjectID actor_creation_dummy_object_id_; - ActorID actor_id_; - ActorID actor_handle_id_; - int64_t actor_counter_; - bool is_actor_checkpoint_method_; - FunctionID function_id_; - int64_t num_returns_; - std::unordered_map resource_map_; -}; - -TaskBuilder *make_task_builder(void) { - return new TaskBuilder(); -} - -void free_task_builder(TaskBuilder *builder) { - delete builder; -} - -bool TaskID_equal(TaskID first_id, TaskID second_id) { - return first_id == second_id; -} - -bool TaskID_is_nil(TaskID id) { - return id.is_nil(); -} - -bool ActorID_equal(ActorID first_id, ActorID second_id) { - return first_id == second_id; -} - -bool FunctionID_equal(FunctionID first_id, FunctionID second_id) { - return first_id == second_id; -} - -bool FunctionID_is_nil(FunctionID id) { - return id.is_nil(); -} - -/* Functions for building tasks. */ - -void TaskSpec_start_construct(TaskBuilder *builder, - UniqueID driver_id, - TaskID parent_task_id, - int64_t parent_counter, - ActorID actor_creation_id, - ObjectID actor_creation_dummy_object_id, - ActorID actor_id, - ActorID actor_handle_id, - int64_t actor_counter, - bool is_actor_checkpoint_method, - FunctionID function_id, - int64_t num_returns) { - builder->Start(driver_id, parent_task_id, parent_counter, actor_creation_id, - actor_creation_dummy_object_id, actor_id, actor_handle_id, - actor_counter, is_actor_checkpoint_method, function_id, - num_returns); -} - -TaskSpec *TaskSpec_finish_construct(TaskBuilder *builder, int64_t *size) { - return reinterpret_cast(builder->Finish(size)); -} - -void TaskSpec_args_add_ref(TaskBuilder *builder, - ObjectID object_ids[], - int num_object_ids) { - builder->NextReferenceArgument(&object_ids[0], num_object_ids); -} - -void TaskSpec_args_add_val(TaskBuilder *builder, - uint8_t *value, - int64_t length) { - builder->NextValueArgument(value, length); -} - -void TaskSpec_set_required_resource(TaskBuilder *builder, - const std::string &resource_name, - double value) { - builder->SetRequiredResource(resource_name, value); -} - -/* Functions for reading tasks. 
*/ - -TaskID TaskSpec_task_id(const TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return from_flatbuf(*message->task_id()); -} - -FunctionID TaskSpec_function(TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return from_flatbuf(*message->function_id()); -} - -ActorID TaskSpec_actor_id(TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return from_flatbuf(*message->actor_id()); -} - -ActorID TaskSpec_actor_handle_id(TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return from_flatbuf(*message->actor_handle_id()); -} - -bool TaskSpec_is_actor_task(TaskSpec *spec) { - return !TaskSpec_actor_id(spec).is_nil(); -} - -ActorID TaskSpec_actor_creation_id(TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return from_flatbuf(*message->actor_creation_id()); -} - -ObjectID TaskSpec_actor_creation_dummy_object_id(TaskSpec *spec) { - RAY_CHECK(spec); - // The task must be an actor method. - RAY_CHECK(TaskSpec_is_actor_task(spec)); - auto message = flatbuffers::GetRoot(spec); - return from_flatbuf(*message->actor_creation_dummy_object_id()); -} - -bool TaskSpec_is_actor_creation_task(TaskSpec *spec) { - return !TaskSpec_actor_creation_id(spec).is_nil(); -} - -int64_t TaskSpec_actor_counter(TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return std::abs(message->actor_counter()); -} - -bool TaskSpec_is_actor_checkpoint_method(TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return message->is_actor_checkpoint_method(); -} - -ObjectID TaskSpec_actor_dummy_object(TaskSpec *spec) { - RAY_CHECK(TaskSpec_is_actor_task(spec)); - /* The last return value for actor tasks is the dummy object that - * represents that this task has completed execution. 
*/ - int64_t num_returns = TaskSpec_num_returns(spec); - return TaskSpec_return(spec, num_returns - 1); -} - -UniqueID TaskSpec_driver_id(const TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return from_flatbuf(*message->driver_id()); -} - -TaskID TaskSpec_parent_task_id(const TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return from_flatbuf(*message->parent_task_id()); -} - -int64_t TaskSpec_parent_counter(TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return message->parent_counter(); -} - -int64_t TaskSpec_num_args(TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return message->args()->size(); -} - -int64_t TaskSpec_num_args_by_ref(TaskSpec *spec) { - int64_t num_args = TaskSpec_num_args(spec); - int64_t num_args_by_ref = 0; - for (int64_t i = 0; i < num_args; i++) { - if (TaskSpec_arg_by_ref(spec, i)) { - num_args_by_ref++; - } - } - return num_args_by_ref; -} - -int TaskSpec_arg_id_count(TaskSpec *spec, int64_t arg_index) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - auto ids = message->args()->Get(arg_index)->object_ids(); - if (ids == nullptr) { - return 0; - } else { - return ids->size(); - } -} - -ObjectID TaskSpec_arg_id(TaskSpec *spec, int64_t arg_index, int64_t id_index) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return from_flatbuf( - *message->args()->Get(arg_index)->object_ids()->Get(id_index)); -} - -const uint8_t *TaskSpec_arg_val(TaskSpec *spec, int64_t arg_index) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return (uint8_t *) message->args()->Get(arg_index)->data()->c_str(); -} - -int64_t TaskSpec_arg_length(TaskSpec *spec, int64_t arg_index) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return message->args()->Get(arg_index)->data()->size(); -} - -int64_t TaskSpec_num_returns(TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return message->returns()->size(); -} - -bool TaskSpec_arg_by_ref(TaskSpec *spec, int64_t arg_index) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return message->args()->Get(arg_index)->object_ids()->size() != 0; -} - -ObjectID TaskSpec_return(TaskSpec *spec, int64_t return_index) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return from_flatbuf(*message->returns()->Get(return_index)); -} - -double TaskSpec_get_required_resource(const TaskSpec *spec, - const std::string &resource_name) { - // This is a bit ugly. However it shouldn't be much of a performance issue - // because there shouldn't be many distinct resources in a single task spec. 
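// ---------------------------------------------------------------------------
// Illustrative sketch, not the deleted implementation: every TaskSpec_* accessor
// above re-parses the spec in place with flatbuffers::GetRoot, which is cheap
// because FlatBuffers reads directly out of the serialized buffer. A toy round
// trip showing that build-then-GetRoot pattern, using a string root instead of
// the generated TaskInfo table (whose schema is not part of this hunk).
#include <flatbuffers/flatbuffers.h>

#include <cassert>
#include <cstring>

int round_trip_example() {
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(fbb.CreateString("remote_function_id"));
  // The builder's buffer can be copied into caller-owned memory (the deleted
  // TaskBuilder::Finish memcpy'd it into a malloc'd TaskSpec) and read back
  // later without a separate deserialization step.
  const uint8_t *buf = fbb.GetBufferPointer();
  auto *root = flatbuffers::GetRoot<flatbuffers::String>(buf);
  assert(std::strcmp(root->c_str(), "remote_function_id") == 0);
  return 0;
}
// ---------------------------------------------------------------------------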
- RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - for (size_t i = 0; i < message->required_resources()->size(); i++) { - const ResourcePair *resource_pair = message->required_resources()->Get(i); - if (string_from_flatbuf(*resource_pair->key()) == resource_name) { - return resource_pair->value(); - } - } - return 0; -} - -const std::unordered_map TaskSpec_get_required_resources( - const TaskSpec *spec) { - RAY_CHECK(spec); - auto message = flatbuffers::GetRoot(spec); - return map_from_flatbuf(*message->required_resources()); -} - -TaskSpec *TaskSpec_copy(TaskSpec *spec, int64_t task_spec_size) { - TaskSpec *copy = (TaskSpec *) malloc(task_spec_size); - memcpy(copy, spec, task_spec_size); - return copy; -} - -void TaskSpec_free(TaskSpec *spec) { - free(spec); -} - -TaskExecutionSpec::TaskExecutionSpec( - const std::vector &execution_dependencies, - const TaskSpec *spec, - int64_t task_spec_size, - int spillback_count) - : execution_dependencies_(execution_dependencies), - task_spec_size_(task_spec_size), - last_timestamp_(0), - spillback_count_(spillback_count) { - TaskSpec *spec_copy = new TaskSpec[task_spec_size_]; - memcpy(spec_copy, spec, task_spec_size); - spec_ = std::unique_ptr(spec_copy); -} - -TaskExecutionSpec::TaskExecutionSpec( - const std::vector &execution_dependencies, - const TaskSpec *spec, - int64_t task_spec_size) - : TaskExecutionSpec(execution_dependencies, spec, task_spec_size, 0) {} - -TaskExecutionSpec::TaskExecutionSpec(TaskExecutionSpec *other) - : execution_dependencies_(other->execution_dependencies_), - task_spec_size_(other->task_spec_size_), - last_timestamp_(other->last_timestamp_), - spillback_count_(other->spillback_count_) { - TaskSpec *spec_copy = new TaskSpec[task_spec_size_]; - memcpy(spec_copy, other->spec_.get(), task_spec_size_); - spec_ = std::unique_ptr(spec_copy); -} - -const std::vector &TaskExecutionSpec::ExecutionDependencies() const { - return execution_dependencies_; -} - -void TaskExecutionSpec::SetExecutionDependencies( - const std::vector &dependencies) { - execution_dependencies_ = dependencies; -} - -int64_t TaskExecutionSpec::SpecSize() const { - return task_spec_size_; -} - -int TaskExecutionSpec::SpillbackCount() const { - return spillback_count_; -} - -void TaskExecutionSpec::IncrementSpillbackCount() { - ++spillback_count_; -} - -int64_t TaskExecutionSpec::LastTimeStamp() const { - return last_timestamp_; -} - -void TaskExecutionSpec::SetLastTimeStamp(int64_t new_timestamp) { - last_timestamp_ = new_timestamp; -} - -TaskSpec *TaskExecutionSpec::Spec() const { - return spec_.get(); -} - -int64_t TaskExecutionSpec::NumDependencies() const { - TaskSpec *spec = Spec(); - int64_t num_dependencies = TaskSpec_num_args(spec); - num_dependencies += execution_dependencies_.size(); - return num_dependencies; -} - -int TaskExecutionSpec::DependencyIdCount(int64_t dependency_index) const { - TaskSpec *spec = Spec(); - /* The first dependencies are the arguments of the task itself, followed by - * the execution dependencies. Find the total number of task arguments so - * that we can index into the correct list. */ - int64_t num_args = TaskSpec_num_args(spec); - if (dependency_index < num_args) { - /* Index into the task arguments. */ - return TaskSpec_arg_id_count(spec, dependency_index); - } else { - /* Index into the execution dependencies. */ - dependency_index -= num_args; - RAY_CHECK((size_t) dependency_index < execution_dependencies_.size()); - /* All elements in the execution dependency list have exactly one ID. 
*/ - return 1; - } -} - -ObjectID TaskExecutionSpec::DependencyId(int64_t dependency_index, - int64_t id_index) const { - TaskSpec *spec = Spec(); - /* The first dependencies are the arguments of the task itself, followed by - * the execution dependencies. Find the total number of task arguments so - * that we can index into the correct list. */ - int64_t num_args = TaskSpec_num_args(spec); - if (dependency_index < num_args) { - /* Index into the task arguments. */ - return TaskSpec_arg_id(spec, dependency_index, id_index); - } else { - /* Index into the execution dependencies. */ - dependency_index -= num_args; - RAY_CHECK((size_t) dependency_index < execution_dependencies_.size()); - return execution_dependencies_[dependency_index]; - } -} - -bool TaskExecutionSpec::DependsOn(ObjectID object_id) const { - // Iterate through the task arguments to see if it contains object_id. - TaskSpec *spec = Spec(); - int64_t num_args = TaskSpec_num_args(spec); - for (int i = 0; i < num_args; ++i) { - int count = TaskSpec_arg_id_count(spec, i); - for (int j = 0; j < count; j++) { - ObjectID arg_id = TaskSpec_arg_id(spec, i, j); - if (arg_id == object_id) { - return true; - } - } - } - // Iterate through the execution dependencies to see if it contains object_id. - for (auto dependency_id : execution_dependencies_) { - if (dependency_id == object_id) { - return true; - } - } - // The requested object ID was not a task argument or an execution dependency. - // This task is not dependent on it. - return false; -} - -bool TaskExecutionSpec::IsStaticDependency(int64_t dependency_index) const { - TaskSpec *spec = Spec(); - /* The first dependencies are the arguments of the task itself, followed by - * the execution dependencies. If the requested dependency index is a task - * argument, then it is a task dependency. 
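// ---------------------------------------------------------------------------
// Illustrative sketch, not the deleted implementation: DependencyIdCount and
// DependencyId above treat a task's dependencies as the immutable task
// arguments followed by the mutable execution dependencies, translating a flat
// index into one of the two lists. The same indexing in isolation, with strings
// standing in for object IDs; DepsSketch is a placeholder type.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

struct DepsSketch {
  std::vector<std::vector<std::string>> arg_ids;  // object IDs per task argument
  std::vector<std::string> execution_deps;        // exactly one ID per entry

  int64_t NumDependencies() const {
    return static_cast<int64_t>(arg_ids.size() + execution_deps.size());
  }

  const std::string &DependencyId(int64_t dep_index, int64_t id_index) const {
    const int64_t num_args = static_cast<int64_t>(arg_ids.size());
    if (dep_index < num_args) {
      return arg_ids[dep_index][id_index];  // a static task argument
    }
    assert(id_index == 0);  // execution dependencies hold a single ID each
    return execution_deps[dep_index - num_args];
  }
};
// ---------------------------------------------------------------------------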
*/ - int64_t num_args = TaskSpec_num_args(spec); - return (dependency_index < num_args); -} - -/* TASK INSTANCES */ - -Task *Task_alloc(const TaskSpec *spec, - int64_t task_spec_size, - TaskStatus state, - DBClientID local_scheduler_id, - const std::vector &execution_dependencies) { - Task *result = new Task(); - auto execution_spec = - new TaskExecutionSpec(execution_dependencies, spec, task_spec_size); - result->execution_spec = std::unique_ptr(execution_spec); - result->state = state; - result->local_scheduler_id = local_scheduler_id; - return result; -} - -Task *Task_alloc(TaskExecutionSpec &execution_spec, - TaskStatus state, - DBClientID local_scheduler_id) { - Task *result = new Task(); - result->execution_spec = std::unique_ptr( - new TaskExecutionSpec(&execution_spec)); - result->state = state; - result->local_scheduler_id = local_scheduler_id; - return result; -} - -Task *Task_copy(Task *other) { - return Task_alloc(*Task_task_execution_spec(other), other->state, - other->local_scheduler_id); -} - -int64_t Task_size(Task *task_arg) { - return sizeof(Task) - sizeof(TaskSpec) + task_arg->execution_spec->SpecSize(); -} - -TaskStatus Task_state(Task *task) { - return task->state; -} - -void Task_set_state(Task *task, TaskStatus state) { - task->state = state; -} - -DBClientID Task_local_scheduler(Task *task) { - return task->local_scheduler_id; -} - -void Task_set_local_scheduler(Task *task, DBClientID local_scheduler_id) { - task->local_scheduler_id = local_scheduler_id; -} - -TaskExecutionSpec *Task_task_execution_spec(Task *task) { - return task->execution_spec.get(); -} - -TaskID Task_task_id(Task *task) { - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - TaskSpec *spec = execution_spec->Spec(); - return TaskSpec_task_id(spec); -} - -void Task_free(Task *task) { - delete task; -} diff --git a/src/common/task.h b/src/common/task.h deleted file mode 100644 index 3984cfdd51195..0000000000000 --- a/src/common/task.h +++ /dev/null @@ -1,609 +0,0 @@ -#ifndef TASK_H -#define TASK_H - -#include - -#include -#include -#include "common.h" - -#include - -#include "format/common_generated.h" - -using namespace ray; - -typedef char TaskSpec; - -class TaskExecutionSpec { - public: - TaskExecutionSpec(const std::vector &execution_dependencies, - const TaskSpec *spec, - int64_t task_spec_size); - TaskExecutionSpec(const std::vector &execution_dependencies, - const TaskSpec *spec, - int64_t task_spec_size, - int spillback_count); - TaskExecutionSpec(TaskExecutionSpec *execution_spec); - - /// Get the task's execution dependencies. - /// - /// @return A vector of object IDs representing this task's execution - /// dependencies. - const std::vector &ExecutionDependencies() const; - - /// Set the task's execution dependencies. - /// - /// @param dependencies The value to set the execution dependencies to. - /// @return Void. - void SetExecutionDependencies(const std::vector &dependencies); - - /// Get the task spec size. - /// - /// @return The size of the immutable task spec. - int64_t SpecSize() const; - - /// Get the task's spillback count, which tracks the number of times - /// this task was spilled back from local to the global scheduler. - /// - /// @return The spillback count for this task. - int SpillbackCount() const; - - /// Increment the spillback count for this task. - /// - /// @return Void. - void IncrementSpillbackCount(); - - /// Get the task's last timestamp. - /// - /// @return The timestamp when this task was last received for scheduling. 
- int64_t LastTimeStamp() const; - - /// Set the task's last timestamp to the specified value. - /// - /// @param new_timestamp The new timestamp in millisecond to set the task's - /// time stamp to. Tracks the last time this task entered a local - /// scheduler. - /// @return Void. - void SetLastTimeStamp(int64_t new_timestamp); - - /// Get the task spec. - /// - /// @return A pointer to the immutable task spec. - TaskSpec *Spec() const; - - /// Get the number of dependencies. This comprises the immutable task - /// arguments and the mutable execution dependencies. - /// - /// @return The number of dependencies. - int64_t NumDependencies() const; - - /// Get the number of object IDs at the given dependency index. - /// - /// @param dependency_index The dependency index whose object IDs to count. - /// @return The number of object IDs at the given dependency_index. - int DependencyIdCount(int64_t dependency_index) const; - - /// Get the object ID of a given dependency index. - /// - /// @param dependency_index The index at which we should look up the object - /// ID. - /// @param id_index The index of the object ID. - ObjectID DependencyId(int64_t dependency_index, int64_t id_index) const; - - /// Compute whether the task is dependent on an object ID. - /// - /// @param object_id The object ID that the task may be dependent on. - /// @return bool This returns true if the task is dependent on the given - /// object ID and false otherwise. - bool DependsOn(ObjectID object_id) const; - - /// Returns whether the given dependency index is a static dependency (an - /// argument of the immutable task). - /// - /// @param dependency_index The requested dependency index. - /// @return bool This returns true if the requested dependency index is - /// immutable (an argument of the task). - bool IsStaticDependency(int64_t dependency_index) const; - - private: - /** A list of object IDs representing this task's dependencies at execution - * time. */ - std::vector execution_dependencies_; - /** The size of the task specification for this task. */ - int64_t task_spec_size_; - /** Last time this task was received for scheduling. */ - int64_t last_timestamp_; - /** Number of times this task was spilled back by local schedulers. */ - int spillback_count_; - /** The task specification for this task. */ - std::unique_ptr spec_; -}; - -class TaskBuilder; - -typedef UniqueID FunctionID; - -/** The task ID is a deterministic hash of the function ID that the task - * executes and the argument IDs or argument values. */ -typedef UniqueID TaskID; - -/** The actor ID is the ID of the actor that a task must run on. If the task is - * not run on an actor, then NIL_ACTOR_ID should be used. */ -typedef UniqueID ActorID; - -/** - * Compare two task IDs. - * - * @param first_id The first task ID to compare. - * @param second_id The first task ID to compare. - * @return True if the task IDs are the same and false otherwise. - */ -bool TaskID_equal(TaskID first_id, TaskID second_id); - -/** - * Compare a task ID to the nil ID. - * - * @param id The task ID to compare to nil. - * @return True if the task ID is equal to nil. - */ -bool TaskID_is_nil(TaskID id); - -/** - * Compare two actor IDs. - * - * @param first_id The first actor ID to compare. - * @param second_id The first actor ID to compare. - * @return True if the actor IDs are the same and false otherwise. - */ -bool ActorID_equal(ActorID first_id, ActorID second_id); - -/** - * Compare two function IDs. - * - * @param first_id The first function ID to compare. 
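/* Illustrative sketch, not part of the deleted header: how the dependency
 * indexing of TaskExecutionSpec is typically consumed. Indices in
 * [0, num_args) address the immutable task arguments (static dependencies);
 * the remaining indices address the mutable execution dependencies. The two
 * helpers below are hypothetical and only use the accessors declared above. */
#include <vector>
#include "task.h"

std::vector<ObjectID> CollectDependencyIds(const TaskExecutionSpec &exec_spec) {
  std::vector<ObjectID> ids;
  for (int64_t i = 0; i < exec_spec.NumDependencies(); ++i) {
    /* A single dependency index can hold several object IDs. */
    for (int j = 0; j < exec_spec.DependencyIdCount(i); ++j) {
      ids.push_back(exec_spec.DependencyId(i, j));
    }
  }
  return ids;
}

void AppendExecutionDependency(TaskExecutionSpec &exec_spec, ObjectID new_dep) {
  /* The task spec itself is immutable, but the execution dependencies can be
   * replaced wholesale. Afterwards DependsOn() reports the new object and
   * IsStaticDependency() still distinguishes it from the task arguments. */
  std::vector<ObjectID> deps = exec_spec.ExecutionDependencies();
  deps.push_back(new_dep);
  exec_spec.SetExecutionDependencies(deps);
}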
- * @param second_id The first function ID to compare. - * @return True if the function IDs are the same and false otherwise. - */ -bool FunctionID_equal(FunctionID first_id, FunctionID second_id); - -/** - * Compare a function ID to the nil ID. - * - * @param id The function ID to compare to nil. - * @return True if the function ID is equal to nil. - */ -bool FunctionID_is_nil(FunctionID id); - -/* Construct and modify task specifications. */ - -TaskBuilder *make_task_builder(void); - -void free_task_builder(TaskBuilder *builder); - -/** - * Begin constructing a task_spec. After this is called, the arguments must be - * added to the task_spec and then finish_construct_task_spec must be called. - * - * @param driver_id The ID of the driver whose job is responsible for the - * creation of this task. - * @param parent_task_id The task ID of the task that submitted this task. - * @param parent_counter A counter indicating how many tasks were submitted by - * the parent task prior to this one. - * @param actor_creation_id The actor creation ID of this task. - * @param actor_creation_dummy_object_id The dummy object for the corresponding - * actor creation task, assuming this is an actor method. - * @param actor_id The ID of the actor that this task is for. If it is not an - * actor task, then this if NIL_ACTOR_ID. - * @param actor_handle_id The ID of the actor handle that this task was - * submitted through. If it is not an actor task, or if this is the - * original handle, then this is NIL_ACTOR_ID. - * @param actor_counter A counter indicating how many tasks have been submitted - * to the same actor before this one. - * @param is_actor_checkpoint_method True if this is an actor checkpoint method - * and false otherwise. - * @param function_id The function ID of the function to execute in this task. - * @param num_args The number of arguments that this task has. - * @param num_returns The number of return values that this task has. - * @param args_value_size The total size in bytes of the arguments to this task - ignoring object ID arguments. - * @return The partially constructed task_spec. - */ -void TaskSpec_start_construct(TaskBuilder *B, - UniqueID driver_id, - TaskID parent_task_id, - int64_t parent_counter, - ActorID actor_creation_id, - ObjectID actor_creation_dummy_object_id, - ActorID actor_id, - ActorHandleID actor_handle_id, - int64_t actor_counter, - bool is_actor_checkpoint_method, - FunctionID function_id, - int64_t num_returns); - -/** - * Finish constructing a task_spec. This computes the task ID and the object IDs - * of the task return values. This must be called after all of the arguments - * have been added to the task. - * - * @param spec The task spec whose ID and return object IDs should be computed. - * @return Void. - */ -TaskSpec *TaskSpec_finish_construct(TaskBuilder *builder, int64_t *size); - -/** - * Return the function ID of the task. - * - * @param spec The task_spec in question. - * @return The function ID of the function to execute in this task. - */ -FunctionID TaskSpec_function(TaskSpec *spec); - -/** - * Return the actor ID of the task. - * - * @param spec The task_spec in question. - * @return The actor ID of the actor the task is part of. - */ -ActorID TaskSpec_actor_id(TaskSpec *spec); - -/** - * Return the actor handle ID of the task. - * - * @param spec The task_spec in question. - * @return The ID of the actor handle that the task was submitted through. 
- */ -ActorID TaskSpec_actor_handle_id(TaskSpec *spec); - -/** - * Return whether this task is for an actor. - * - * @param spec The task_spec in question. - * @return Whether the task is for an actor. - */ -bool TaskSpec_is_actor_task(TaskSpec *spec); - -/// Return whether this task is an actor creation task or not. -/// -/// \param spec The task_spec in question. -/// \return True if this task is an actor creation task and false otherwise. -bool TaskSpec_is_actor_creation_task(TaskSpec *spec); - -/// Return the actor creation ID of the task. The task must be an actor creation -/// task. -/// -/// \param spec The task_spec in question. -/// \return The actor creation ID if this is an actor creation task. -ActorID TaskSpec_actor_creation_id(TaskSpec *spec); - -/// Return the actor creation dummy object ID of the task. The task must be an -/// actor task. -/// -/// \param spec The task_spec in question. -/// \return The actor creation dummy object ID corresponding to this actor task. -ObjectID TaskSpec_actor_creation_dummy_object_id(TaskSpec *spec); - -/** - * Return the actor counter of the task. This starts at 0 and increments by 1 - * every time a new task is submitted to run on the actor. - * - * @param spec The task_spec in question. - * @return The actor counter of the task. - */ -int64_t TaskSpec_actor_counter(TaskSpec *spec); - -/** - * Return whether the task is a checkpoint method execution. - * - * @param spec The task_spec in question. - * @return Whether the task is a checkpoint method. - */ -bool TaskSpec_is_actor_checkpoint_method(TaskSpec *spec); - -/** - * Return an actor task's dummy return value. Dummy objects are used to - * encode an actor's state dependencies in the task graph. The dummy object - * is local if and only if the task that returned it has completed - * execution. - * - * @param spec The task_spec in question. - * @return The dummy object ID that the actor task will return. - */ -ObjectID TaskSpec_actor_dummy_object(TaskSpec *spec); - -/** - * Return the driver ID of the task. - * - * @param spec The task_spec in question. - * @return The driver ID of the task. - */ -UniqueID TaskSpec_driver_id(const TaskSpec *spec); - -/** - * Return the task ID of the parent task. - * - * @param spec The task_spec in question. - * @return The task ID of the parent task. - */ -TaskID TaskSpec_parent_task_id(const TaskSpec *spec); - -/** - * Return the task counter of the parent task. For example, this equals 5 if - * this task was the 6th task submitted by the parent task. - * - * @param spec The task_spec in question. - * @return The task counter of the parent task. - */ -int64_t TaskSpec_parent_counter(TaskSpec *spec); - -/** - * Return the task ID of the task. - * - * @param spec The task_spec in question. - * @return The task ID of the task. - */ -TaskID TaskSpec_task_id(const TaskSpec *spec); - -/** - * Get the number of arguments to this task. - * - * @param spec The task_spec in question. - * @return The number of arguments to this task. - */ -int64_t TaskSpec_num_args(TaskSpec *spec); - -/** - * Get the number of return values expected from this task. - * - * @param spec The task_spec in question. - * @return The number of return values expected from this task. - */ -int64_t TaskSpec_num_returns(TaskSpec *spec); - -/** - * Return true if this argument is passed by reference. - * - * @param spec The task_spec in question. - * @param arg_index The index of the argument in question. - * @return True if this argument is passed by reference. 
- */ -bool TaskSpec_arg_by_ref(TaskSpec *spec, int64_t arg_index); - -/** - * Get number of object IDs in a given argument - * - * @param spec The task_spec in question. - * @param arg_index The index of the argument in question. - * @return number of object IDs in this argument - */ -int TaskSpec_arg_id_count(TaskSpec *spec, int64_t arg_index); - -/** - * Get a particular argument to this task. This assumes the argument is an - * object ID. - * - * @param spec The task_spec in question. - * @param arg_index The index of the argument in question. - * @param id_index The index of the object ID in this arg. - * @return The argument at that index. - */ -ObjectID TaskSpec_arg_id(TaskSpec *spec, int64_t arg_index, int64_t id_index); - -/** - * Get a particular argument to this task. This assumes the argument is a value. - * - * @param spec The task_spec in question. - * @param arg_index The index of the argument in question. - * @return The argument at that index. - */ -const uint8_t *TaskSpec_arg_val(TaskSpec *spec, int64_t arg_index); - -/** - * Get the number of bytes in a particular argument to this task. This assumes - * the argument is a value. - * - * @param spec The task_spec in question. - * @param arg_index The index of the argument in question. - * @return The number of bytes in the argument. - */ -int64_t TaskSpec_arg_length(TaskSpec *spec, int64_t arg_index); - -/** - * Set the next task argument. Note that this API only allows you to set the - * arguments in their order of appearance. - * - * @param spec The task_spec in question. - * @param object_ids The object IDs to set the argument to. - * @param num_object_ids number of IDs in this param, usually 1. - * @return The number of task arguments that have been set before this one. This - * is only used for testing. - */ -void TaskSpec_args_add_ref(TaskBuilder *spec, - ObjectID object_ids[], - int num_object_ids); - -/** - * Set the next task argument. Note that this API only allows you to set the - * arguments in their order of appearance. - * - * @param spec The task_spec in question. - * @param The value to set the argument to. - * @param The length of the value to set the argument to. - * @return The number of task arguments that have been set before this one. This - * is only used for testing. - */ -void TaskSpec_args_add_val(TaskBuilder *builder, - uint8_t *value, - int64_t length); - -/** - * Set the value associated to a resource index. - * - * @param spec Task specification. - * @param resource_name Name of the resource in the resource vector. - * @param value Value for the resource. This can be a quantity of this resource - * this task needs or a value for an attribute this task requires. - * @return Void. - */ -void TaskSpec_set_required_resource(TaskBuilder *builder, - const std::string &resource_name, - double value); - -/** - * Get a particular return object ID of a task. - * - * @param spec The task_spec in question. - * @param return_index The index of the return object ID in question. - * @return The relevant return object ID. - */ -ObjectID TaskSpec_return(TaskSpec *data, int64_t return_index); - -/** - * Get the value associated to a resource name. - * - * @param spec Task specification. - * @param resource_name Name of the resource. - * @return How many of this resource the task needs to execute. 
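/* Illustrative sketch of the builder workflow documented above. The function
 * name and the concrete argument values are hypothetical; the call pattern
 * mirrors the example_task.h test helper that appears later in this diff.
 * A non-actor task passes nil IDs and zero counters for the actor fields. */
#include <cstdint>
#include <cstring>
#include "task.h"

TaskSpec *BuildExampleSpec(TaskBuilder *builder, int64_t *spec_size) {
  TaskID parent_task_id = TaskID::from_random();
  FunctionID function_id = FunctionID::from_random();
  TaskSpec_start_construct(builder, UniqueID::nil(), parent_task_id,
                           /* parent_counter */ 0, ActorID::nil(),
                           ObjectID::nil(), ActorID::nil(), ActorID::nil(),
                           /* actor_counter */ 0,
                           /* is_actor_checkpoint_method */ false, function_id,
                           /* num_returns */ 1);
  /* Arguments must be added in their order of appearance: here one
   * pass-by-reference object ID followed by one small pass-by-value payload. */
  ObjectID arg_id = ObjectID::from_random();
  TaskSpec_args_add_ref(builder, &arg_id, 1);
  const char *payload = "hello";
  TaskSpec_args_add_val(builder, (uint8_t *) payload, strlen(payload));
  TaskSpec_set_required_resource(builder, "CPU", 1.0);
  /* Finishing the construction computes the task ID and the return object
   * IDs; the caller owns the result and releases it with TaskSpec_free. */
  return TaskSpec_finish_construct(builder, spec_size);
}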
- */ -double TaskSpec_get_required_resource(const TaskSpec *spec, - const std::string &resource_name); - -/** - * - */ -const std::unordered_map TaskSpec_get_required_resources( - const TaskSpec *spec); - -/** - * Compute the object id associated to a put call. - * - * @param task_id The task id of the parent task that called the put. - * @param put_index The number of put calls in this task so far. - * @return The object ID for the object that was put. - */ -ObjectID task_compute_put_id(TaskID task_id, int64_t put_index); - -/** - * Print the task as a humanly readable string. - * - * @param spec The task_spec in question. - * @return The humanly readable string. - */ -std::string TaskSpec_print(TaskSpec *spec); - -/** - * Create a copy of the task spec. Must be freed with TaskSpec_free after use. - * - * @param spec The task specification that will be copied. - * @param task_spec_size The size of the task specification in bytes. - * @returns Pointer to the copy of the task specification. - */ -TaskSpec *TaskSpec_copy(TaskSpec *spec, int64_t task_spec_size); - -/** - * Free a task_spec. - * - * @param The task_spec in question. - * @return Void. - */ -void TaskSpec_free(TaskSpec *spec); - -/** - * ==== Task ==== - * Contains information about a scheduled task: The task specification, the - * task scheduling state (WAITING, SCHEDULED, QUEUED, RUNNING, DONE), and which - * local scheduler the task is scheduled on. - */ - -/** The scheduling_state can be used as a flag when we are listening - * for an event, for example TASK_WAITING | TASK_SCHEDULED. */ -enum class TaskStatus : uint { - /** The task is waiting to be scheduled. */ - WAITING = 1, - /** The task has been scheduled to a node, but has not been queued yet. */ - SCHEDULED = 2, - /** The task has been queued on a node, where it will wait for its - * dependencies to become ready and a worker to become available. */ - QUEUED = 4, - /** The task is running on a worker. */ - RUNNING = 8, - /** The task is done executing. */ - DONE = 16, - /** The task was not able to finish. */ - LOST = 32, - /** The task will be submitted for reexecution. */ - RECONSTRUCTING = 64, - /** An actor task is cached at a local scheduler and is waiting for the - * corresponding actor to be created. */ - ACTOR_CACHED = 128 -}; - -inline TaskStatus operator|(const TaskStatus &a, const TaskStatus &b) { - uint c = static_cast(a) | static_cast(b); - return static_cast(c); -} - -/** A task is an execution of a task specification. It has a state of execution - * (see scheduling_state) and the ID of the local scheduler it is scheduled on - * or running on. */ - -struct Task { - /** The scheduling state of the task. */ - TaskStatus state; - /** The ID of the local scheduler involved. */ - DBClientID local_scheduler_id; - /** The execution specification for this task. */ - std::unique_ptr execution_spec; -}; - -/** - * Allocate a new task. Must be freed with free_task after use. - * - * @param spec The task spec for the new task. - * @param state The scheduling state for the new task. - * @param local_scheduler_id The ID of the local scheduler that the task is - * scheduled on, if any. - */ -Task *Task_alloc(const TaskSpec *spec, - int64_t task_spec_size, - TaskStatus state, - DBClientID local_scheduler_id, - const std::vector &execution_dependencies); - -Task *Task_alloc(TaskExecutionSpec &execution_spec, - TaskStatus state, - DBClientID local_scheduler_id); - -/** - * Create a copy of the task. Must be freed with Task_free after use. 
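/* Illustrative sketch of how the TaskStatus flags combine into a subscription
 * mask. Only operator| is defined in the deleted header, so the membership
 * test below casts back to the underlying integer type; StatusMatches and
 * ExampleStatusUsage are hypothetical helpers built only on the declarations
 * shown in this file. */
#include "task.h"

bool StatusMatches(TaskStatus mask, TaskStatus state) {
  /* True if `state` is one of the states OR-ed into `mask`. */
  return (static_cast<uint>(mask) & static_cast<uint>(state)) != 0;
}

void ExampleStatusUsage(TaskExecutionSpec &execution_spec) {
  /* A subscriber interested in tasks that are waiting or already scheduled
   * would pass a combined mask, e.g. to task_table_subscribe. */
  TaskStatus mask = TaskStatus::WAITING | TaskStatus::SCHEDULED;
  Task *task = Task_alloc(execution_spec, TaskStatus::WAITING, UniqueID::nil());
  RAY_CHECK(StatusMatches(mask, Task_state(task)));
  Task_set_state(task, TaskStatus::RUNNING);
  RAY_CHECK(!StatusMatches(mask, Task_state(task)));
  Task_free(task);
}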
- * - * @param other The task that will be copied. - * @returns Pointer to the copy of the task. - */ -Task *Task_copy(Task *other); - -/** Size of task structure in bytes. */ -int64_t Task_size(Task *task); - -/** The scheduling state of the task. */ -TaskStatus Task_state(Task *task); - -/** Update the schedule state of the task. */ -void Task_set_state(Task *task, TaskStatus state); - -/** Local scheduler this task has been assigned to or is running on. */ -DBClientID Task_local_scheduler(Task *task); - -/** Set the local scheduler ID for this task. */ -void Task_set_local_scheduler(Task *task, DBClientID local_scheduler_id); - -TaskExecutionSpec *Task_task_execution_spec(Task *task); - -/** Task ID of this task. */ -TaskID Task_task_id(Task *task); - -/** Free this task datastructure. */ -void Task_free(Task *task); - -#endif /* TASK_H */ diff --git a/src/common/test/db_tests.cc b/src/common/test/db_tests.cc deleted file mode 100644 index 83585ca66e0f9..0000000000000 --- a/src/common/test/db_tests.cc +++ /dev/null @@ -1,246 +0,0 @@ -#include "greatest.h" - -#include -#include -#include - -#include "event_loop.h" -#include "test_common.h" -#include "example_task.h" -#include "net.h" -#include "state/db.h" -#include "state/db_client_table.h" -#include "state/object_table.h" -#include "state/task_table.h" -#include "state/redis.h" -#include "task.h" - -SUITE(db_tests); - -TaskBuilder *g_task_builder = NULL; - -/* Retry 10 times with an 100ms timeout. */ -const int NUM_RETRIES = 10; -const uint64_t TIMEOUT = 50; - -const char *manager_addr = "127.0.0.1"; -int manager_port1 = 12345; -int manager_port2 = 12346; -char received_addr1[16] = {0}; -int received_port1; -char received_addr2[16] = {0}; -int received_port2; - -typedef struct { int test_number; } user_context; - -const int TEST_NUMBER = 10; - -/* Test if entries have been written to the database. */ - -void lookup_done_callback(ObjectID object_id, - bool never_created, - const std::vector &manager_ids, - void *user_context) { - DBHandle *db = (DBHandle *) user_context; - RAY_CHECK(manager_ids.size() == 2); - const std::vector managers = - db_client_table_get_ip_addresses(db, manager_ids); - RAY_CHECK(parse_ip_addr_port(managers.at(0).c_str(), received_addr1, - &received_port1) == 0); - RAY_CHECK(parse_ip_addr_port(managers.at(1).c_str(), received_addr2, - &received_port2) == 0); -} - -/* Entry added to database successfully. */ -void add_done_callback(ObjectID object_id, bool success, void *user_context) {} - -/* Test if we got a timeout callback if we couldn't connect database. */ -void timeout_callback(ObjectID object_id, void *context, void *user_data) { - user_context *uc = (user_context *) context; - RAY_CHECK(uc->test_number == TEST_NUMBER); -} - -int64_t timeout_handler(event_loop *loop, int64_t id, void *context) { - event_loop_stop(loop); - return EVENT_LOOP_TIMER_DONE; -} - -TEST object_table_lookup_test(void) { - event_loop *loop = event_loop_create(); - /* This uses manager_port1. */ - std::vector db_connect_args1; - db_connect_args1.push_back("manager_address"); - db_connect_args1.push_back("127.0.0.1:12345"); - DBHandle *db1 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - manager_addr, db_connect_args1); - /* This uses manager_port2. 
*/ - std::vector db_connect_args2; - db_connect_args2.push_back("manager_address"); - db_connect_args2.push_back("127.0.0.1:12346"); - DBHandle *db2 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - manager_addr, db_connect_args2); - db_attach(db1, loop, false); - db_attach(db2, loop, false); - UniqueID id = UniqueID::from_random(); - RetryInfo retry = { - .num_retries = NUM_RETRIES, - .timeout = TIMEOUT, - .fail_callback = timeout_callback, - }; - object_table_add(db1, id, 0, (unsigned char *) NIL_DIGEST, &retry, - add_done_callback, NULL); - object_table_add(db2, id, 0, (unsigned char *) NIL_DIGEST, &retry, - add_done_callback, NULL); - event_loop_add_timer(loop, 200, (event_loop_timer_handler) timeout_handler, - NULL); - event_loop_run(loop); - object_table_lookup(db1, id, &retry, lookup_done_callback, db1); - event_loop_add_timer(loop, 200, (event_loop_timer_handler) timeout_handler, - NULL); - event_loop_run(loop); - ASSERT_STR_EQ(&received_addr1[0], manager_addr); - ASSERT((received_port1 == manager_port1 && received_port2 == manager_port2) || - (received_port2 == manager_port1 && received_port1 == manager_port2)); - - db_disconnect(db1); - db_disconnect(db2); - - destroy_outstanding_callbacks(loop); - event_loop_destroy(loop); - PASS(); -} - -int task_table_test_callback_called = 0; -Task *task_table_test_task; - -void task_table_test_fail_callback(UniqueID id, - void *context, - void *user_data) { - event_loop *loop = (event_loop *) user_data; - event_loop_stop(loop); -} - -int64_t task_table_delayed_add_task(event_loop *loop, - int64_t id, - void *context) { - DBHandle *db = (DBHandle *) context; - RetryInfo retry = { - .num_retries = NUM_RETRIES, - .timeout = TIMEOUT, - .fail_callback = task_table_test_fail_callback, - }; - task_table_add_task(db, Task_copy(task_table_test_task), &retry, NULL, - (void *) loop); - return EVENT_LOOP_TIMER_DONE; -} - -void task_table_test_callback(Task *callback_task, void *user_data) { - task_table_test_callback_called = 1; - RAY_CHECK(Task_state(callback_task) == TaskStatus::SCHEDULED); - RAY_CHECK(Task_size(callback_task) == Task_size(task_table_test_task)); - RAY_CHECK(Task_equals(callback_task, task_table_test_task)); - event_loop *loop = (event_loop *) user_data; - event_loop_stop(loop); -} - -TEST task_table_test(void) { - task_table_test_callback_called = 0; - event_loop *loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "local_scheduler", - "127.0.0.1", std::vector()); - db_attach(db, loop, false); - DBClientID local_scheduler_id = DBClientID::from_random(); - TaskExecutionSpec spec = example_task_execution_spec(1, 1); - task_table_test_task = - Task_alloc(spec, TaskStatus::SCHEDULED, local_scheduler_id); - RetryInfo retry = { - .num_retries = NUM_RETRIES, - .timeout = TIMEOUT, - .fail_callback = task_table_test_fail_callback, - }; - task_table_subscribe(db, local_scheduler_id, TaskStatus::SCHEDULED, - task_table_test_callback, (void *) loop, &retry, NULL, - (void *) loop); - event_loop_add_timer( - loop, 200, (event_loop_timer_handler) task_table_delayed_add_task, db); - event_loop_run(loop); - Task_free(task_table_test_task); - db_disconnect(db); - destroy_outstanding_callbacks(loop); - event_loop_destroy(loop); - ASSERT(task_table_test_callback_called); - PASS(); -} - -int num_test_callback_called = 0; - -void task_table_all_test_callback(Task *task, void *user_data) { - num_test_callback_called += 1; -} - -TEST task_table_all_test(void) { - event_loop *loop = event_loop_create(); - 
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "local_scheduler", - "127.0.0.1", std::vector()); - db_attach(db, loop, false); - TaskExecutionSpec spec = example_task_execution_spec(1, 1); - /* Schedule two tasks on different local local schedulers. */ - Task *task1 = - Task_alloc(spec, TaskStatus::SCHEDULED, DBClientID::from_random()); - Task *task2 = - Task_alloc(spec, TaskStatus::SCHEDULED, DBClientID::from_random()); - RetryInfo retry = { - .num_retries = NUM_RETRIES, .timeout = TIMEOUT, .fail_callback = NULL, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::SCHEDULED, - task_table_all_test_callback, NULL, &retry, NULL, NULL); - event_loop_add_timer(loop, 50, (event_loop_timer_handler) timeout_handler, - NULL); - event_loop_run(loop); - /* TODO(pcm): Get rid of this sleep once the robust pubsub is implemented. */ - task_table_add_task(db, task1, &retry, NULL, NULL); - task_table_add_task(db, task2, &retry, NULL, NULL); - event_loop_add_timer(loop, 200, (event_loop_timer_handler) timeout_handler, - NULL); - event_loop_run(loop); - db_disconnect(db); - destroy_outstanding_callbacks(loop); - event_loop_destroy(loop); - ASSERT(num_test_callback_called == 2); - PASS(); -} - -TEST unique_client_id_test(void) { - const int num_conns = 100; - - DBClientID ids[num_conns]; - DBHandle *db; - for (int i = 0; i < num_conns; ++i) { - db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - ids[i] = get_db_client_id(db); - db_disconnect(db); - } - for (int i = 0; i < num_conns; ++i) { - for (int j = 0; j < i; ++j) { - ASSERT(!(ids[i] == ids[j])); - } - } - PASS(); -} - -SUITE(db_tests) { - RUN_REDIS_TEST(object_table_lookup_test); - RUN_REDIS_TEST(task_table_test); - RUN_REDIS_TEST(task_table_all_test); - RUN_REDIS_TEST(unique_client_id_test); -} - -GREATEST_MAIN_DEFS(); - -int main(int argc, char **argv) { - g_task_builder = make_task_builder(); - GREATEST_MAIN_BEGIN(); - RUN_SUITE(db_tests); - GREATEST_MAIN_END(); -} diff --git a/src/common/test/example_task.h b/src/common/test/example_task.h deleted file mode 100644 index f90cab68f6d95..0000000000000 --- a/src/common/test/example_task.h +++ /dev/null @@ -1,77 +0,0 @@ -#ifndef EXAMPLE_TASK_H -#define EXAMPLE_TASK_H - -#include "task.h" - -extern TaskBuilder *g_task_builder; - -const int64_t arg_value_size = 1000; - -static inline TaskExecutionSpec example_task_execution_spec_with_args( - int64_t num_args, - int64_t num_returns, - ObjectID arg_ids[]) { - TaskID parent_task_id = TaskID::from_random(); - FunctionID func_id = FunctionID::from_random(); - TaskSpec_start_construct(g_task_builder, UniqueID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, num_returns); - for (int64_t i = 0; i < num_args; ++i) { - ObjectID arg_id; - if (arg_ids == NULL) { - arg_id = ObjectID::from_random(); - } else { - arg_id = arg_ids[i]; - } - TaskSpec_args_add_ref(g_task_builder, &arg_id, 1); - } - int64_t task_spec_size; - TaskSpec *spec = TaskSpec_finish_construct(g_task_builder, &task_spec_size); - std::vector execution_dependencies; - auto execution_spec = - TaskExecutionSpec(execution_dependencies, spec, task_spec_size); - TaskSpec_free(spec); - return execution_spec; -} - -static inline TaskExecutionSpec example_task_execution_spec( - int64_t num_args, - int64_t num_returns) { - return example_task_execution_spec_with_args(num_args, num_returns, NULL); -} - -static inline Task *example_task_with_args(int64_t num_args, - int64_t 
num_returns, - TaskStatus task_state, - ObjectID arg_ids[]) { - TaskExecutionSpec spec = - example_task_execution_spec_with_args(num_args, num_returns, arg_ids); - Task *instance = Task_alloc(spec, task_state, UniqueID::nil()); - return instance; -} - -static inline Task *example_task(int64_t num_args, - int64_t num_returns, - TaskStatus task_state) { - TaskExecutionSpec spec = example_task_execution_spec(num_args, num_returns); - Task *instance = Task_alloc(spec, task_state, UniqueID::nil()); - return instance; -} - -static inline bool Task_equals(Task *task1, Task *task2) { - if (task1->state != task2->state) { - return false; - } - if (!(task1->local_scheduler_id == task2->local_scheduler_id)) { - return false; - } - auto execution_spec1 = Task_task_execution_spec(task1); - auto execution_spec2 = Task_task_execution_spec(task2); - if (execution_spec1->SpecSize() != execution_spec2->SpecSize()) { - return false; - } - return memcmp(execution_spec1->Spec(), execution_spec2->Spec(), - execution_spec1->SpecSize()) == 0; -} - -#endif /* EXAMPLE_TASK_H */ diff --git a/src/common/test/io_tests.cc b/src/common/test/io_tests.cc deleted file mode 100644 index 092ca97b7d56d..0000000000000 --- a/src/common/test/io_tests.cc +++ /dev/null @@ -1,114 +0,0 @@ -#include "greatest.h" - -#include -#include -#include - -#include -#include - -#include "io.h" - -SUITE(io_tests); - -TEST ipc_socket_test(void) { -#ifndef _WIN32 - const char *socket_pathname = "/tmp/test-socket"; - int socket_fd = bind_ipc_sock(socket_pathname, true); - ASSERT(socket_fd >= 0); - - const char *test_string = "hello world"; - const char *test_bytes = "another string"; - pid_t pid = fork(); - if (pid == 0) { - close(socket_fd); - socket_fd = connect_ipc_sock(socket_pathname); - ASSERT(socket_fd >= 0); - write_log_message(socket_fd, test_string); - write_message(socket_fd, - static_cast(CommonMessageType::LOG_MESSAGE), - strlen(test_bytes), (uint8_t *) test_bytes); - close(socket_fd); - exit(0); - } else { - int client_fd = accept_client(socket_fd); - ASSERT(client_fd >= 0); - char *message = read_log_message(client_fd); - ASSERT(message != NULL); - ASSERT_STR_EQ(test_string, message); - free(message); - int64_t type; - int64_t len; - uint8_t *bytes; - read_message(client_fd, &type, &len, &bytes); - ASSERT(static_cast(type) == - CommonMessageType::LOG_MESSAGE); - ASSERT(memcmp(test_bytes, bytes, len) == 0); - free(bytes); - close(client_fd); - close(socket_fd); - unlink(socket_pathname); - } -#endif - PASS(); -} - -TEST long_ipc_socket_test(void) { -#ifndef _WIN32 - const char *socket_pathname = "/tmp/long-test-socket"; - int socket_fd = bind_ipc_sock(socket_pathname, true); - ASSERT(socket_fd >= 0); - - std::stringstream test_string_ss; - for (int i = 0; i < 10000; i++) { - test_string_ss << "hello world "; - } - std::string test_string = test_string_ss.str(); - const char *test_bytes = "another string"; - pid_t pid = fork(); - if (pid == 0) { - close(socket_fd); - socket_fd = connect_ipc_sock(socket_pathname); - ASSERT(socket_fd >= 0); - write_log_message(socket_fd, test_string.c_str()); - write_message(socket_fd, - static_cast(CommonMessageType::LOG_MESSAGE), - strlen(test_bytes), (uint8_t *) test_bytes); - close(socket_fd); - exit(0); - } else { - int client_fd = accept_client(socket_fd); - ASSERT(client_fd >= 0); - char *message = read_log_message(client_fd); - ASSERT(message != NULL); - ASSERT_STR_EQ(test_string.c_str(), message); - free(message); - int64_t type; - int64_t len; - uint8_t *bytes; - read_message(client_fd, &type, 
&len, &bytes); - ASSERT(static_cast(type) == - CommonMessageType::LOG_MESSAGE); - ASSERT(memcmp(test_bytes, bytes, len) == 0); - free(bytes); - close(client_fd); - close(socket_fd); - unlink(socket_pathname); - } - -#endif - PASS(); -} - -SUITE(io_tests) { - RUN_TEST(ipc_socket_test); - RUN_TEST(long_ipc_socket_test); -} - -GREATEST_MAIN_DEFS(); - -int main(int argc, char **argv) { - GREATEST_MAIN_BEGIN(); - RUN_SUITE(io_tests); - GREATEST_MAIN_END(); -} diff --git a/src/common/test/object_table_tests.cc b/src/common/test/object_table_tests.cc deleted file mode 100644 index 0599724386069..0000000000000 --- a/src/common/test/object_table_tests.cc +++ /dev/null @@ -1,919 +0,0 @@ -#include "greatest.h" - -#include "event_loop.h" -#include "example_task.h" -#include "test_common.h" -#include "common.h" -#include "state/db_client_table.h" -#include "state/object_table.h" -#include "state/redis.h" - -#include - -SUITE(object_table_tests); - -static event_loop *g_loop; -TaskBuilder *g_task_builder = NULL; - -/* ==== Test adding and looking up metadata ==== */ - -int new_object_failed = 0; -int new_object_succeeded = 0; -ObjectID new_object_id; -Task *new_object_task; -TaskSpec *new_object_task_spec; -TaskID new_object_task_id; - -void new_object_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - new_object_failed = 1; - event_loop_stop(g_loop); -} - -/* === Test adding an object with an associated task === */ - -void new_object_done_callback(ObjectID object_id, - TaskID task_id, - bool is_put, - void *user_context) { - new_object_succeeded = 1; - RAY_CHECK(object_id == new_object_id); - RAY_CHECK(task_id == new_object_task_id); - event_loop_stop(g_loop); -} - -void new_object_lookup_callback(ObjectID object_id, void *user_context) { - RAY_CHECK(object_id == new_object_id); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = new_object_fail_callback, - }; - DBHandle *db = (DBHandle *) user_context; - result_table_lookup(db, new_object_id, &retry, new_object_done_callback, - NULL); -} - -void new_object_task_callback(TaskID task_id, void *user_context) { - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = new_object_fail_callback, - }; - DBHandle *db = (DBHandle *) user_context; - result_table_add(db, new_object_id, new_object_task_id, false, &retry, - new_object_lookup_callback, (void *) db); -} - -void task_table_subscribe_done(TaskID task_id, void *user_context) { - RetryInfo retry = { - .num_retries = 5, .timeout = 100, .fail_callback = NULL, - }; - DBHandle *db = (DBHandle *) user_context; - task_table_add_task(db, Task_copy(new_object_task), &retry, - new_object_task_callback, db); -} - -TEST new_object_test(void) { - new_object_failed = 0; - new_object_succeeded = 0; - new_object_id = ObjectID::from_random(); - new_object_task = example_task(1, 1, TaskStatus::WAITING); - new_object_task_spec = Task_task_execution_spec(new_object_task)->Spec(); - new_object_task_id = TaskSpec_task_id(new_object_task_spec); - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = new_object_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, task_table_subscribe_done, db); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - 
ASSERT(new_object_succeeded); - ASSERT(!new_object_failed); - PASS(); -} - -/* === Test adding an object without an associated task === */ - -void new_object_no_task_callback(ObjectID object_id, - TaskID task_id, - bool is_put, - void *user_context) { - new_object_succeeded = 1; - RAY_CHECK(task_id.is_nil()); - event_loop_stop(g_loop); -} - -TEST new_object_no_task_test(void) { - new_object_failed = 0; - new_object_succeeded = 0; - new_object_id = ObjectID::from_random(); - new_object_task_id = TaskID::from_random(); - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = new_object_fail_callback, - }; - result_table_lookup(db, new_object_id, &retry, new_object_no_task_callback, - NULL); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(new_object_succeeded); - ASSERT(!new_object_failed); - PASS(); -} - -/* ==== Test if operations time out correctly ==== */ - -/* === Test lookup timeout === */ - -const char *lookup_timeout_context = "lookup_timeout"; -int lookup_failed = 0; - -void lookup_done_callback(ObjectID object_id, - bool never_created, - const std::vector &manager_vector, - void *context) { - /* The done callback should not be called. */ - RAY_CHECK(0); -} - -void lookup_fail_callback(UniqueID id, void *user_context, void *user_data) { - lookup_failed = 1; - RAY_CHECK(user_context == (void *) lookup_timeout_context); - event_loop_stop(g_loop); -} - -TEST lookup_timeout_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, .timeout = 100, .fail_callback = lookup_fail_callback, - }; - object_table_lookup(db, UniqueID::nil(), &retry, lookup_done_callback, - (void *) lookup_timeout_context); - /* Disconnect the database to see if the lookup times out. */ - close(db->context->c.fd); - for (auto context : db->contexts) { - close(context->c.fd); - } - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(lookup_failed); - PASS(); -} - -/* === Test add timeout === */ - -const char *add_timeout_context = "add_timeout"; -int add_failed = 0; - -void add_done_callback(ObjectID object_id, bool success, void *user_context) { - /* The done callback should not be called. */ - RAY_CHECK(0); -} - -void add_fail_callback(UniqueID id, void *user_context, void *user_data) { - add_failed = 1; - RAY_CHECK(user_context == (void *) add_timeout_context); - event_loop_stop(g_loop); -} - -TEST add_timeout_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, .timeout = 100, .fail_callback = add_fail_callback, - }; - object_table_add(db, UniqueID::nil(), 0, (unsigned char *) NIL_DIGEST, &retry, - add_done_callback, (void *) add_timeout_context); - /* Disconnect the database to see if the lookup times out. 
*/ - close(db->context->c.fd); - for (auto context : db->contexts) { - close(context->c.fd); - } - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(add_failed); - PASS(); -} - -/* === Test subscribe timeout === */ - -int subscribe_failed = 0; - -void subscribe_done_callback(ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *user_context) { - /* The done callback should not be called. */ - RAY_CHECK(0); -} - -void subscribe_fail_callback(UniqueID id, void *user_context, void *user_data) { - subscribe_failed = 1; - event_loop_stop(g_loop); -} - -TEST subscribe_timeout_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = subscribe_fail_callback, - }; - object_table_subscribe_to_notifications(db, false, subscribe_done_callback, - NULL, &retry, NULL, NULL); - /* Disconnect the database to see if the lookup times out. */ - close(db->subscribe_context->c.fd); - for (auto subscribe_context : db->subscribe_contexts) { - close(subscribe_context->c.fd); - } - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_failed); - PASS(); -} - -/* ==== Test if the retry is working correctly ==== */ - -int64_t reconnect_context_callback(event_loop *loop, - int64_t timer_id, - void *context) { - DBHandle *db = (DBHandle *) context; - /* Reconnect to redis. This is not reconnecting the pub/sub channel. */ - redisAsyncFree(db->context); - redisFree(db->sync_context); - db->context = redisAsyncConnect("127.0.0.1", 6379); - db->context->data = (void *) db; - db->sync_context = redisConnect("127.0.0.1", 6379); - /* Re-attach the database to the event loop (the file descriptor changed). */ - db_attach(db, loop, true); - RAY_LOG(DEBUG) << "Reconnected to Redis"; - return EVENT_LOOP_TIMER_DONE; -} - -int64_t terminate_event_loop_callback(event_loop *loop, - int64_t timer_id, - void *context) { - event_loop_stop(loop); - return EVENT_LOOP_TIMER_DONE; -} - -/* === Test lookup retry === */ - -const char *lookup_retry_context = "lookup_retry"; -int lookup_retry_succeeded = 0; - -void lookup_retry_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - /* The fail callback should not be called. 
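/* Illustrative sketch of the bounded event-loop scaffold shared by the tests
 * in this file: always install a timer that stops the loop, so a callback that
 * never fires cannot hang the test (the tests above use a 750 ms bound). The
 * wrapper below is hypothetical; it only restates the terminate pattern. */
#include "event_loop.h"

static int64_t stop_loop_callback(event_loop *loop, int64_t timer_id,
                                  void *context) {
  event_loop_stop(loop);
  return EVENT_LOOP_TIMER_DONE;
}

static void run_loop_with_deadline(event_loop *loop, int64_t timeout_ms) {
  /* The timer fires once after timeout_ms and terminates the loop even if no
   * Redis callback ever completes. */
  event_loop_add_timer(loop, timeout_ms,
                       (event_loop_timer_handler) stop_loop_callback, NULL);
  event_loop_run(loop);
}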
*/ - RAY_CHECK(0); -} - -/* === Test add retry === */ - -const char *add_retry_context = "add_retry"; -int add_retry_succeeded = 0; - -/* === Test add then lookup retry === */ - -void add_lookup_done_callback(ObjectID object_id, - bool never_created, - const std::vector &manager_ids, - void *context) { - DBHandle *db = (DBHandle *) context; - RAY_CHECK(manager_ids.size() == 1); - const std::vector managers = - db_client_table_get_ip_addresses(db, manager_ids); - RAY_CHECK(managers.at(0) == "127.0.0.1:11235"); - lookup_retry_succeeded = 1; -} - -void add_lookup_callback(ObjectID object_id, bool success, void *user_context) { - RAY_CHECK(success); - DBHandle *db = (DBHandle *) user_context; - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = lookup_retry_fail_callback, - }; - object_table_lookup(db, UniqueID::nil(), &retry, add_lookup_done_callback, - (void *) db); -} - -TEST add_lookup_test(void) { - g_loop = event_loop_create(); - lookup_retry_succeeded = 0; - /* Construct the arguments to db_connect. */ - std::vector db_connect_args; - db_connect_args.push_back("manager_address"); - db_connect_args.push_back("127.0.0.1:11235"); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", db_connect_args); - db_attach(db, g_loop, true); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = lookup_retry_fail_callback, - }; - object_table_add(db, UniqueID::nil(), 0, (unsigned char *) NIL_DIGEST, &retry, - add_lookup_callback, (void *) db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(lookup_retry_succeeded); - PASS(); -} - -/* === Test add, remove, then lookup === */ -void add_remove_lookup_done_callback( - ObjectID object_id, - bool never_created, - const std::vector &manager_vector, - void *context) { - RAY_CHECK(context == (void *) lookup_retry_context); - RAY_CHECK(manager_vector.size() == 0); - lookup_retry_succeeded = 1; -} - -void add_remove_lookup_callback(ObjectID object_id, - bool success, - void *user_context) { - RAY_CHECK(success); - DBHandle *db = (DBHandle *) user_context; - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = lookup_retry_fail_callback, - }; - object_table_lookup(db, UniqueID::nil(), &retry, - add_remove_lookup_done_callback, - (void *) lookup_retry_context); -} - -void add_remove_callback(ObjectID object_id, bool success, void *user_context) { - RAY_CHECK(success); - DBHandle *db = (DBHandle *) user_context; - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = lookup_retry_fail_callback, - }; - object_table_remove(db, UniqueID::nil(), NULL, &retry, - add_remove_lookup_callback, (void *) db); -} - -TEST add_remove_lookup_test(void) { - g_loop = event_loop_create(); - lookup_retry_succeeded = 0; - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, true); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = lookup_retry_fail_callback, - }; - object_table_add(db, UniqueID::nil(), 0, (unsigned char *) NIL_DIGEST, &retry, - add_remove_callback, (void *) db); - /* Install handler for terminating the event loop. 
*/ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(lookup_retry_succeeded); - PASS(); -} - -/* ==== Test if late succeed is working correctly ==== */ - -/* === Test lookup late succeed === */ - -const char *lookup_late_context = "lookup_late"; -int lookup_late_failed = 0; - -void lookup_late_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - RAY_CHECK(user_context == (void *) lookup_late_context); - lookup_late_failed = 1; -} - -void lookup_late_done_callback(ObjectID object_id, - bool never_created, - const std::vector &manager_vector, - void *context) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST lookup_late_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 0, - .timeout = 0, - .fail_callback = lookup_late_fail_callback, - }; - object_table_lookup(db, UniqueID::nil(), &retry, lookup_late_done_callback, - (void *) lookup_late_context); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* First process timer events to make sure the timeout is processed before - * anything else. */ - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(lookup_late_failed); - PASS(); -} - -/* === Test add late succeed === */ - -const char *add_late_context = "add_late"; -int add_late_failed = 0; - -void add_late_fail_callback(UniqueID id, void *user_context, void *user_data) { - RAY_CHECK(user_context == (void *) add_late_context); - add_late_failed = 1; -} - -void add_late_done_callback(ObjectID object_id, - bool success, - void *user_context) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST add_late_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 0, .timeout = 0, .fail_callback = add_late_fail_callback, - }; - object_table_add(db, UniqueID::nil(), 0, (unsigned char *) NIL_DIGEST, &retry, - add_late_done_callback, (void *) add_late_context); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* First process timer events to make sure the timeout is processed before - * anything else. 
*/ - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(add_late_failed); - PASS(); -} - -/* === Test subscribe late succeed === */ - -const char *subscribe_late_context = "subscribe_late"; -int subscribe_late_failed = 0; - -void subscribe_late_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - RAY_CHECK(user_context == (void *) subscribe_late_context); - subscribe_late_failed = 1; -} - -void subscribe_late_done_callback(ObjectID object_id, - bool never_created, - const std::vector &manager_vector, - void *user_context) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST subscribe_late_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 0, - .timeout = 0, - .fail_callback = subscribe_late_fail_callback, - }; - object_table_subscribe_to_notifications(db, false, NULL, NULL, &retry, - subscribe_late_done_callback, - (void *) subscribe_late_context); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* First process timer events to make sure the timeout is processed before - * anything else. */ - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_late_failed); - PASS(); -} - -/* === Test subscribe object available succeed === */ - -const char *subscribe_success_context = "subscribe_success"; -int subscribe_success_done = 0; -int subscribe_success_succeeded = 0; -ObjectID subscribe_id; - -void subscribe_success_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -void subscribe_success_done_callback( - ObjectID object_id, - bool never_created, - const std::vector &manager_vector, - void *user_context) { - RetryInfo retry = { - .num_retries = 0, .timeout = 750, .fail_callback = NULL, - }; - object_table_add((DBHandle *) user_context, subscribe_id, 0, - (unsigned char *) NIL_DIGEST, &retry, NULL, NULL); - subscribe_success_done = 1; -} - -void subscribe_success_object_available_callback( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *user_context) { - RAY_CHECK(user_context == (void *) subscribe_success_context); - RAY_CHECK(object_id == subscribe_id); - RAY_CHECK(manager_vector.size() == 1); - subscribe_success_succeeded = 1; -} - -TEST subscribe_success_test(void) { - g_loop = event_loop_create(); - - /* Construct the arguments to db_connect. 
*/ - std::vector db_connect_args; - db_connect_args.push_back("manager_address"); - db_connect_args.push_back("127.0.0.1:11236"); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", db_connect_args); - db_attach(db, g_loop, false); - subscribe_id = ObjectID::from_random(); - - RetryInfo retry = { - .num_retries = 0, - .timeout = 100, - .fail_callback = subscribe_success_fail_callback, - }; - object_table_subscribe_to_notifications( - db, false, subscribe_success_object_available_callback, - (void *) subscribe_success_context, &retry, - subscribe_success_done_callback, (void *) db); - - ObjectID object_ids[1] = {subscribe_id}; - object_table_request_notifications(db, 1, object_ids, &retry); - - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - - ASSERT(subscribe_success_done); - ASSERT(subscribe_success_succeeded); - PASS(); -} - -/* Test if subscribe succeeds if the object is already present. */ -typedef struct { - const char *teststr; - int64_t data_size; -} subscribe_object_present_context_t; - -const char *subscribe_object_present_str = "subscribe_object_present"; -int subscribe_object_present_succeeded = 0; - -void subscribe_object_present_object_available_callback( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *user_context) { - subscribe_object_present_context_t *ctx = - (subscribe_object_present_context_t *) user_context; - RAY_CHECK(ctx->data_size == data_size); - RAY_CHECK(strcmp(subscribe_object_present_str, ctx->teststr) == 0); - subscribe_object_present_succeeded = 1; - RAY_CHECK(manager_vector.size() == 1); -} - -void fatal_fail_callback(UniqueID id, void *user_context, void *user_data) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST subscribe_object_present_test(void) { - int64_t data_size = 0xF1F0; - subscribe_object_present_context_t myctx = {subscribe_object_present_str, - data_size}; - - g_loop = event_loop_create(); - /* Construct the arguments to db_connect. */ - std::vector db_connect_args; - db_connect_args.push_back("manager_address"); - db_connect_args.push_back("127.0.0.1:11236"); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", db_connect_args); - db_attach(db, g_loop, false); - UniqueID id = UniqueID::from_random(); - RetryInfo retry = { - .num_retries = 0, .timeout = 100, .fail_callback = fatal_fail_callback, - }; - object_table_add(db, id, data_size, (unsigned char *) NIL_DIGEST, &retry, - NULL, NULL); - object_table_subscribe_to_notifications( - db, false, subscribe_object_present_object_available_callback, - (void *) &myctx, &retry, NULL, (void *) db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to create do the add and subscribe. */ - event_loop_run(g_loop); - - ObjectID object_ids[1] = {id}; - object_table_request_notifications(db, 1, object_ids, &retry); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the request notifications. 
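/* Illustrative sketch of the notification protocol these tests exercise:
 * subscribe once per DB handle, then request notifications for specific
 * object IDs; the available-callback fires whether the object is already
 * present or is added later. The helper names and the element type of
 * manager_vector (assumed here to be DBClientID) are illustrative. */
#include <vector>
#include "state/object_table.h"

static void on_object_available(ObjectID object_id, int64_t data_size,
                                const std::vector<DBClientID> &manager_vector,
                                void *user_context) {
  RAY_LOG(DEBUG) << "object available on " << manager_vector.size()
                 << " manager(s), data size " << data_size;
}

static void watch_object(DBHandle *db, ObjectID id, RetryInfo *retry) {
  /* Register the callback once per DB handle ... */
  object_table_subscribe_to_notifications(db, /* subscribe_all */ false,
                                          on_object_available, NULL, retry,
                                          NULL, NULL);
  /* ... then ask for notifications about the specific IDs of interest. */
  ObjectID object_ids[1] = {id};
  object_table_request_notifications(db, 1, object_ids, retry);
}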
*/ - event_loop_run(g_loop); - - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_object_present_succeeded == 1); - PASS(); -} - -/* Test if subscribe is not called if object is not present. */ - -const char *subscribe_object_not_present_context = - "subscribe_object_not_present"; - -void subscribe_object_not_present_object_available_callback( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *user_context) { - /* This should not be called. */ - RAY_CHECK(0); -} - -TEST subscribe_object_not_present_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - UniqueID id = UniqueID::from_random(); - RetryInfo retry = { - .num_retries = 0, .timeout = 100, .fail_callback = NULL, - }; - object_table_subscribe_to_notifications( - db, false, subscribe_object_not_present_object_available_callback, - (void *) subscribe_object_not_present_context, &retry, NULL, (void *) db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the subscribe. */ - event_loop_run(g_loop); - - ObjectID object_ids[1] = {id}; - object_table_request_notifications(db, 1, object_ids, &retry); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the request notifications. */ - event_loop_run(g_loop); - - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - PASS(); -} - -/* Test if subscribe is called if object becomes available later. */ - -const char *subscribe_object_available_later_context = - "subscribe_object_available_later"; -int subscribe_object_available_later_succeeded = 0; - -void subscribe_object_available_later_object_available_callback( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *user_context) { - subscribe_object_present_context_t *myctx = - (subscribe_object_present_context_t *) user_context; - RAY_CHECK(myctx->data_size == data_size); - RAY_CHECK(strcmp(myctx->teststr, subscribe_object_available_later_context) == - 0); - /* Make sure the callback is only called once. */ - subscribe_object_available_later_succeeded += 1; - RAY_CHECK(manager_vector.size() == 1); -} - -TEST subscribe_object_available_later_test(void) { - int64_t data_size = 0xF1F0; - subscribe_object_present_context_t *myctx = - (subscribe_object_present_context_t *) malloc( - sizeof(subscribe_object_present_context_t)); - myctx->teststr = subscribe_object_available_later_context; - myctx->data_size = data_size; - - g_loop = event_loop_create(); - /* Construct the arguments to db_connect. 
*/ - std::vector db_connect_args; - db_connect_args.push_back("manager_address"); - db_connect_args.push_back("127.0.0.1:11236"); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", db_connect_args); - db_attach(db, g_loop, false); - UniqueID id = UniqueID::from_random(); - RetryInfo retry = { - .num_retries = 0, .timeout = 100, .fail_callback = NULL, - }; - object_table_subscribe_to_notifications( - db, false, subscribe_object_available_later_object_available_callback, - (void *) myctx, &retry, NULL, (void *) db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the subscribe. */ - event_loop_run(g_loop); - - ObjectID object_ids[1] = {id}; - object_table_request_notifications(db, 1, object_ids, &retry); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the request notifications. */ - event_loop_run(g_loop); - - ASSERT_EQ(subscribe_object_available_later_succeeded, 0); - object_table_add(db, id, data_size, (unsigned char *) NIL_DIGEST, &retry, - NULL, NULL); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the object table add. */ - event_loop_run(g_loop); - - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT_EQ(subscribe_object_available_later_succeeded, 1); - /* Reset the global variable before exiting this unit test. */ - subscribe_object_available_later_succeeded = 0; - free(myctx); - PASS(); -} - -TEST subscribe_object_available_subscribe_all(void) { - int64_t data_size = 0xF1F0; - subscribe_object_present_context_t myctx = { - subscribe_object_available_later_context, data_size}; - g_loop = event_loop_create(); - /* Construct the arguments to db_connect. */ - std::vector db_connect_args; - db_connect_args.push_back("manager_address"); - db_connect_args.push_back("127.0.0.1:11236"); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", db_connect_args); - db_attach(db, g_loop, false); - UniqueID id = UniqueID::from_random(); - RetryInfo retry = { - .num_retries = 0, .timeout = 100, .fail_callback = NULL, - }; - object_table_subscribe_to_notifications( - db, true, subscribe_object_available_later_object_available_callback, - (void *) &myctx, &retry, NULL, (void *) db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the subscribe. */ - event_loop_run(g_loop); - - /* At this point we don't expect any object notifications received. */ - ASSERT_EQ(subscribe_object_available_later_succeeded, 0); - object_table_add(db, id, data_size, (unsigned char *) NIL_DIGEST, &retry, - NULL, NULL); - /* Install handler to terminate event loop after 750ms. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the object table add. */ - event_loop_run(g_loop); - /* At this point we assume that object table add completed. 
*/ - - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - /* Assert that the object table add completed and notification callback fired. - */ - printf("subscribe_all object info test: callback fired: %d times\n", - subscribe_object_available_later_succeeded); - fflush(stdout); - ASSERT_EQ(subscribe_object_available_later_succeeded, 1); - /* Reset the global variable before exiting this unit test. */ - subscribe_object_available_later_succeeded = 0; - PASS(); -} - -SUITE(object_table_tests) { - RUN_REDIS_TEST(new_object_test); - RUN_REDIS_TEST(new_object_no_task_test); - // RUN_REDIS_TEST(lookup_timeout_test); - // RUN_REDIS_TEST(add_timeout_test); - // RUN_REDIS_TEST(subscribe_timeout_test); - RUN_REDIS_TEST(add_lookup_test); - RUN_REDIS_TEST(add_remove_lookup_test); - // RUN_REDIS_TEST(lookup_late_test); - // RUN_REDIS_TEST(add_late_test); - // RUN_REDIS_TEST(subscribe_late_test); - RUN_REDIS_TEST(subscribe_success_test); - RUN_REDIS_TEST(subscribe_object_not_present_test); - RUN_REDIS_TEST(subscribe_object_available_later_test); - RUN_REDIS_TEST(subscribe_object_available_subscribe_all); -} - -GREATEST_MAIN_DEFS(); - -int main(int argc, char **argv) { - g_task_builder = make_task_builder(); - GREATEST_MAIN_BEGIN(); - RUN_SUITE(object_table_tests); - GREATEST_MAIN_END(); -} diff --git a/src/common/test/redis_tests.cc b/src/common/test/redis_tests.cc deleted file mode 100644 index 7db7ae2ee26e9..0000000000000 --- a/src/common/test/redis_tests.cc +++ /dev/null @@ -1,238 +0,0 @@ -#include "greatest.h" - -#include -#include - -#include - -#include "event_loop.h" -#include "state/db.h" -#include "state/redis.h" -#include "io.h" -#include "logging.h" -#include "test_common.h" - -SUITE(redis_tests); - -const char *test_set_format = "SET %s %s"; -const char *test_get_format = "GET %s"; -const char *test_key = "foo"; -const char *test_value = "bar"; -std::vector connections; - -void write_formatted_log_message(int socket_fd, const char *format, ...) 
{ - va_list ap; - - /* Get cmd size */ - va_start(ap, format); - size_t cmd_size = vsnprintf(nullptr, 0, format, ap) + 1; - va_end(ap); - - /* Print va to cmd */ - char cmd[cmd_size]; - va_start(ap, format); - vsnprintf(cmd, cmd_size, format, ap); - va_end(ap); - - write_log_message(socket_fd, cmd); -} - -int async_redis_socket_test_callback_called = 0; - -void async_redis_socket_test_callback(redisAsyncContext *ac, - void *r, - void *privdata) { - async_redis_socket_test_callback_called = 1; - redisContext *context = redisConnect("127.0.0.1", 6379); - redisReply *reply = - (redisReply *) redisCommand(context, test_get_format, test_key); - redisFree(context); - RAY_CHECK(reply != NULL); - if (strcmp(reply->str, test_value)) { - freeReplyObject(reply); - RAY_CHECK(0); - } - freeReplyObject(reply); -} - -TEST redis_socket_test(void) { - const char *socket_pathname = "/tmp/redis-test-socket"; - redisContext *context = redisConnect("127.0.0.1", 6379); - ASSERT(context != NULL); - int socket_fd = bind_ipc_sock(socket_pathname, true); - ASSERT(socket_fd >= 0); - - int client_fd = connect_ipc_sock(socket_pathname); - ASSERT(client_fd >= 0); - write_formatted_log_message(client_fd, test_set_format, test_key, test_value); - - int server_fd = accept_client(socket_fd); - char *cmd = read_log_message(server_fd); - close(client_fd); - close(server_fd); - close(socket_fd); - unlink(socket_pathname); - - redisReply *reply = (redisReply *) redisCommand(context, cmd, 0, 0); - freeReplyObject(reply); - reply = (redisReply *) redisCommand(context, "GET %s", test_key); - ASSERT(reply != NULL); - ASSERT_STR_EQ(reply->str, test_value); - freeReplyObject(reply); - - free(cmd); - redisFree(context); - PASS(); -} - -void redis_read_callback(event_loop *loop, int fd, void *context, int events) { - DBHandle *db = (DBHandle *) context; - char *cmd = read_log_message(fd); - redisAsyncCommand(db->context, async_redis_socket_test_callback, NULL, cmd); - free(cmd); -} - -void redis_accept_callback(event_loop *loop, - int socket_fd, - void *context, - int events) { - int accept_fd = accept_client(socket_fd); - RAY_CHECK(accept_fd >= 0); - connections.push_back(accept_fd); - event_loop_add_file(loop, accept_fd, EVENT_LOOP_READ, redis_read_callback, - context); -} - -int timeout_handler(event_loop *loop, timer_id timer_id, void *context) { - event_loop_stop(loop); - return EVENT_LOOP_TIMER_DONE; -} - -TEST async_redis_socket_test(void) { - event_loop *loop = event_loop_create(); - - /* Start IPC channel. */ - const char *socket_pathname = "/tmp/async-redis-test-socket"; - int socket_fd = bind_ipc_sock(socket_pathname, true); - ASSERT(socket_fd >= 0); - connections.push_back(socket_fd); - - /* Start connection to Redis. */ - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "test_process", - "127.0.0.1", std::vector()); - db_attach(db, loop, false); - - /* Send a command to the Redis process. 
*/ - int client_fd = connect_ipc_sock(socket_pathname); - ASSERT(client_fd >= 0); - connections.push_back(client_fd); - write_formatted_log_message(client_fd, test_set_format, test_key, test_value); - - event_loop_add_file(loop, client_fd, EVENT_LOOP_READ, redis_read_callback, - db); - event_loop_add_file(loop, socket_fd, EVENT_LOOP_READ, redis_accept_callback, - db); - event_loop_add_timer(loop, 100, timeout_handler, NULL); - event_loop_run(loop); - - ASSERT(async_redis_socket_test_callback_called); - - db_disconnect(db); - event_loop_destroy(loop); - - for (int const &p : connections) { - close(p); - } - unlink(socket_pathname); - connections.clear(); - PASS(); -} - -int logging_test_callback_called = 0; - -void logging_test_callback(redisAsyncContext *ac, void *r, void *privdata) { - logging_test_callback_called = 1; - redisContext *context = redisConnect("127.0.0.1", 6379); - redisReply *reply = (redisReply *) redisCommand(context, "KEYS %s", "log:*"); - redisFree(context); - RAY_CHECK(reply != NULL); - RAY_CHECK(reply->elements > 0); - freeReplyObject(reply); -} - -void logging_read_callback(event_loop *loop, - int fd, - void *context, - int events) { - DBHandle *conn = (DBHandle *) context; - char *cmd = read_log_message(fd); - redisAsyncCommand(conn->context, logging_test_callback, NULL, cmd, - (char *) conn->client.data(), sizeof(conn->client)); - free(cmd); -} - -void logging_accept_callback(event_loop *loop, - int socket_fd, - void *context, - int events) { - int accept_fd = accept_client(socket_fd); - RAY_CHECK(accept_fd >= 0); - connections.push_back(accept_fd); - event_loop_add_file(loop, accept_fd, EVENT_LOOP_READ, logging_read_callback, - context); -} - -TEST logging_test(void) { - event_loop *loop = event_loop_create(); - - /* Start IPC channel. */ - const char *socket_pathname = "/tmp/logging-test-socket"; - int socket_fd = bind_ipc_sock(socket_pathname, true); - ASSERT(socket_fd >= 0); - connections.push_back(socket_fd); - - /* Start connection to Redis. */ - DBHandle *conn = db_connect(std::string("127.0.0.1"), 6379, "test_process", - "127.0.0.1", std::vector()); - db_attach(conn, loop, false); - - /* Send a command to the Redis process. 
   */
-  int client_fd = connect_ipc_sock(socket_pathname);
-  ASSERT(client_fd >= 0);
-  connections.push_back(client_fd);
-  RayLogger *logger = RayLogger_init("worker", RAY_LOG_INFO, 0, &client_fd);
-  RayLogger_log(logger, RAY_LOG_INFO, "TEST", "Message");
-
-  event_loop_add_file(loop, socket_fd, EVENT_LOOP_READ, logging_accept_callback,
-                      conn);
-  event_loop_add_file(loop, client_fd, EVENT_LOOP_READ, logging_read_callback,
-                      conn);
-  event_loop_add_timer(loop, 100, timeout_handler, NULL);
-  event_loop_run(loop);
-
-  ASSERT(logging_test_callback_called);
-
-  RayLogger_free(logger);
-  db_disconnect(conn);
-  event_loop_destroy(loop);
-  for (int const &p : connections) {
-    close(p);
-  }
-  unlink(socket_pathname);
-  connections.clear();
-  PASS();
-}
-
-SUITE(redis_tests) {
-  RUN_REDIS_TEST(redis_socket_test);
-  RUN_REDIS_TEST(async_redis_socket_test);
-  RUN_REDIS_TEST(logging_test);
-}
-
-GREATEST_MAIN_DEFS();
-
-int main(int argc, char **argv) {
-  GREATEST_MAIN_BEGIN();
-  RUN_SUITE(redis_tests);
-  GREATEST_MAIN_END();
-}
diff --git a/src/common/test/run_tests.sh b/src/common/test/run_tests.sh
deleted file mode 100644
index 5ccb1e3f92ffa..0000000000000
--- a/src/common/test/run_tests.sh
+++ /dev/null
@@ -1,43 +0,0 @@
-#!/usr/bin/env bash
-
-# This needs to be run in the build tree, which is normally ray/build
-
-# Cause the script to exit if a single command fails.
-set -ex
-
-LaunchRedis() {
-  port=$1
-  if [[ "${RAY_USE_NEW_GCS}" = "on" ]]; then
-    ./src/credis/redis/src/redis-server \
-      --loglevel warning \
-      --loadmodule ./src/credis/build/src/libmember.so \
-      --loadmodule ./src/common/redis_module/libray_redis_module.so \
-      --port $port &
-  else
-    ./src/common/thirdparty/redis/src/redis-server \
-      --loglevel warning \
-      --loadmodule ./src/common/redis_module/libray_redis_module.so \
-      --port $port &
-  fi
-  sleep 1s
-}
-
-
-# Start the Redis shards.
-LaunchRedis 6379
-LaunchRedis 6380
-# Register the shard location with the primary shard.
-./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
-./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
-
-if [ -z "$RAY_USE_NEW_GCS" ]; then
-  ./src/common/db_tests
-  ./src/common/io_tests
-  ./src/common/task_tests
-  ./src/common/redis_tests
-  ./src/common/task_table_tests
-  ./src/common/object_table_tests
-fi
-
-./src/common/thirdparty/redis/src/redis-cli -p 6379 shutdown
-./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
diff --git a/src/common/test/run_valgrind.sh b/src/common/test/run_valgrind.sh
deleted file mode 100644
index 418a91366e132..0000000000000
--- a/src/common/test/run_valgrind.sh
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/usr/bin/env bash
-
-# This needs to be run in the build tree, which is normally ray/build
-
-set -x
-
-# Cause the script to exit if a single command fails.
-set -e
-
-if [ -z "$RAY_USE_NEW_GCS" ]; then
-  # Start the Redis shards.
-  ./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
-  ./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 &
-  sleep 1s
-  # Register the shard location with the primary shard.
-  ./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
-  ./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
-
-  valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/db_tests
-  valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/io_tests
-  valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/task_tests
-  valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/redis_tests
-  valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/task_table_tests
-  valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/object_table_tests
-  ./src/common/thirdparty/redis/src/redis-cli shutdown
-  ./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
-fi
diff --git a/src/common/test/task_table_tests.cc b/src/common/test/task_table_tests.cc
deleted file mode 100644
index f94aca3b132c4..0000000000000
--- a/src/common/test/task_table_tests.cc
+++ /dev/null
@@ -1,460 +0,0 @@
-#include "greatest.h"
-
-#include "event_loop.h"
-#include "example_task.h"
-#include "test_common.h"
-#include "common.h"
-#include "state/object_table.h"
-#include "state/redis.h"
-
-#include
-#include
-
-SUITE(task_table_tests);
-
-event_loop *g_loop;
-TaskBuilder *g_task_builder = NULL;
-
-/* ==== Test operations in non-failure scenario ==== */
-
-/* === A lookup of a task not in the table === */
-
-TaskID lookup_nil_id;
-int lookup_nil_success = 0;
-const char *lookup_nil_context = "lookup_nil";
-
-void lookup_nil_fail_callback(UniqueID id,
-                              void *user_context,
-                              void *user_data) {
-  /* The fail callback should not be called. */
-  RAY_CHECK(0);
-}
-
-void lookup_nil_success_callback(Task *task, void *context) {
-  lookup_nil_success = 1;
-  RAY_CHECK(task == NULL);
-  RAY_CHECK(context == (void *) lookup_nil_context);
-  event_loop_stop(g_loop);
-}
-
-TEST lookup_nil_test(void) {
-  lookup_nil_id = TaskID::from_random();
-  g_loop = event_loop_create();
-  DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
-                            "127.0.0.1", std::vector<std::string>());
-  db_attach(db, g_loop, false);
-  RetryInfo retry = {
-      .num_retries = 5,
-      .timeout = 1000,
-      .fail_callback = lookup_nil_fail_callback,
-  };
-  task_table_get_task(db, lookup_nil_id, &retry, lookup_nil_success_callback,
-                      (void *) lookup_nil_context);
-  /* Disconnect the database to see if the lookup times out. */
-  event_loop_run(g_loop);
-  db_disconnect(db);
-  destroy_outstanding_callbacks(g_loop);
-  event_loop_destroy(g_loop);
-  ASSERT(lookup_nil_success);
-  PASS();
-}
-
-/* === A lookup of a task after it's added returns the same spec === */
-
-int add_success = 0;
-int lookup_success = 0;
-Task *add_lookup_task;
-const char *add_lookup_context = "add_lookup";
-
-void add_lookup_fail_callback(UniqueID id,
-                              void *user_context,
-                              void *user_data) {
-  /* The fail callback should not be called.
*/ - RAY_CHECK(0); -} - -void lookup_success_callback(Task *task, void *context) { - lookup_success = 1; - RAY_CHECK(Task_equals(task, add_lookup_task)); - event_loop_stop(g_loop); -} - -void add_success_callback(TaskID task_id, void *context) { - add_success = 1; - RAY_CHECK(TaskID_equal(task_id, Task_task_id(add_lookup_task))); - - DBHandle *db = (DBHandle *) context; - RetryInfo retry = { - .num_retries = 5, - .timeout = 1000, - .fail_callback = add_lookup_fail_callback, - }; - task_table_get_task(db, task_id, &retry, lookup_success_callback, - (void *) add_lookup_context); -} - -void subscribe_success_callback(TaskID task_id, void *context) { - DBHandle *db = (DBHandle *) context; - RetryInfo retry = { - .num_retries = 5, - .timeout = 1000, - .fail_callback = add_lookup_fail_callback, - }; - task_table_add_task(db, Task_copy(add_lookup_task), &retry, - add_success_callback, (void *) db); -} - -TEST add_lookup_test(void) { - add_lookup_task = example_task(1, 1, TaskStatus::WAITING); - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 1000, - .fail_callback = add_lookup_fail_callback, - }; - /* Wait for subscription to succeed before adding the task. */ - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, subscribe_success_callback, (void *) db); - /* Disconnect the database to see if the lookup times out. */ - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(add_success); - ASSERT(lookup_success); - PASS(); -} - -/* ==== Test if operations time out correctly ==== */ - -/* === Test subscribe timeout === */ - -const char *subscribe_timeout_context = "subscribe_timeout"; -int subscribe_failed = 0; - -void subscribe_done_callback(TaskID task_id, void *user_context) { - /* The done callback should not be called. */ - RAY_CHECK(0); -} - -void subscribe_fail_callback(UniqueID id, void *user_context, void *user_data) { - subscribe_failed = 1; - RAY_CHECK(user_context == (void *) subscribe_timeout_context); - event_loop_stop(g_loop); -} - -TEST subscribe_timeout_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = subscribe_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, subscribe_done_callback, - (void *) subscribe_timeout_context); - /* Disconnect the database to see if the subscribe times out. */ - close(db->subscribe_context->c.fd); - for (size_t i = 0; i < db->subscribe_contexts.size(); ++i) { - close(db->subscribe_contexts[i]->c.fd); - } - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_failed); - PASS(); -} - -/* === Test publish timeout === */ - -const char *publish_timeout_context = "publish_timeout"; -int publish_failed = 0; - -void publish_done_callback(TaskID task_id, void *user_context) { - /* The done callback should not be called. 
*/ - RAY_CHECK(0); -} - -void publish_fail_callback(UniqueID id, void *user_context, void *user_data) { - publish_failed = 1; - RAY_CHECK(user_context == (void *) publish_timeout_context); - event_loop_stop(g_loop); -} - -TEST publish_timeout_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - Task *task = example_task(1, 1, TaskStatus::WAITING); - RetryInfo retry = { - .num_retries = 5, .timeout = 100, .fail_callback = publish_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, NULL, NULL); - task_table_add_task(db, task, &retry, publish_done_callback, - (void *) publish_timeout_context); - /* Disconnect the database to see if the publish times out. */ - close(db->context->c.fd); - for (size_t i = 0; i < db->contexts.size(); ++i) { - close(db->contexts[i]->c.fd); - } - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(publish_failed); - PASS(); -} - -/* ==== Test if the retry is working correctly ==== */ - -int64_t reconnect_db_callback(event_loop *loop, - int64_t timer_id, - void *context) { - DBHandle *db = (DBHandle *) context; - /* Reconnect to redis. */ - redisAsyncFree(db->subscribe_context); - db->subscribe_context = redisAsyncConnect("127.0.0.1", 6379); - db->subscribe_context->data = (void *) db; - for (size_t i = 0; i < db->subscribe_contexts.size(); ++i) { - redisAsyncFree(db->subscribe_contexts[i]); - db->subscribe_contexts[i] = redisAsyncConnect("127.0.0.1", 6380 + i); - db->subscribe_contexts[i]->data = (void *) db; - } - /* Re-attach the database to the event loop (the file descriptor changed). */ - db_attach(db, loop, true); - return EVENT_LOOP_TIMER_DONE; -} - -int64_t terminate_event_loop_callback(event_loop *loop, - int64_t timer_id, - void *context) { - event_loop_stop(loop); - return EVENT_LOOP_TIMER_DONE; -} - -/* === Test subscribe retry === */ - -const char *subscribe_retry_context = "subscribe_retry"; -int subscribe_retry_succeeded = 0; - -void subscribe_retry_done_callback(ObjectID object_id, void *user_context) { - RAY_CHECK(user_context == (void *) subscribe_retry_context); - subscribe_retry_succeeded = 1; -} - -void subscribe_retry_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - /* The fail callback should not be called. */ - RAY_CHECK(0); -} - -TEST subscribe_retry_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = subscribe_retry_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, subscribe_retry_done_callback, - (void *) subscribe_retry_context); - /* Disconnect the database to see if the subscribe times out. */ - close(db->subscribe_context->c.fd); - for (size_t i = 0; i < db->subscribe_contexts.size(); ++i) { - close(db->subscribe_contexts[i]->c.fd); - } - /* Install handler for reconnecting the database. */ - event_loop_add_timer(g_loop, 150, - (event_loop_timer_handler) reconnect_db_callback, db); - /* Install handler for terminating the event loop. 
*/ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_retry_succeeded); - PASS(); -} - -/* === Test publish retry === */ - -const char *publish_retry_context = "publish_retry"; -int publish_retry_succeeded = 0; - -void publish_retry_done_callback(ObjectID object_id, void *user_context) { - RAY_CHECK(user_context == (void *) publish_retry_context); - publish_retry_succeeded = 1; -} - -void publish_retry_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - /* The fail callback should not be called. */ - RAY_CHECK(0); -} - -TEST publish_retry_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - Task *task = example_task(1, 1, TaskStatus::WAITING); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = publish_retry_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, NULL, NULL); - task_table_add_task(db, task, &retry, publish_retry_done_callback, - (void *) publish_retry_context); - /* Disconnect the database to see if the publish times out. */ - close(db->subscribe_context->c.fd); - for (size_t i = 0; i < db->subscribe_contexts.size(); ++i) { - close(db->subscribe_contexts[i]->c.fd); - } - /* Install handler for reconnecting the database. */ - event_loop_add_timer(g_loop, 150, - (event_loop_timer_handler) reconnect_db_callback, db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(publish_retry_succeeded); - PASS(); -} - -/* ==== Test if late succeed is working correctly ==== */ - -/* === Test subscribe late succeed === */ - -const char *subscribe_late_context = "subscribe_late"; -int subscribe_late_failed = 0; - -void subscribe_late_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - RAY_CHECK(user_context == (void *) subscribe_late_context); - subscribe_late_failed = 1; -} - -void subscribe_late_done_callback(TaskID task_id, void *user_context) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST subscribe_late_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 0, - .timeout = 0, - .fail_callback = subscribe_late_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, subscribe_late_done_callback, - (void *) subscribe_late_context); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* First process timer events to make sure the timeout is processed before - * anything else. 
*/ - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_late_failed); - PASS(); -} - -/* === Test publish late succeed === */ - -const char *publish_late_context = "publish_late"; -int publish_late_failed = 0; - -void publish_late_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - RAY_CHECK(user_context == (void *) publish_late_context); - publish_late_failed = 1; -} - -void publish_late_done_callback(TaskID task_id, void *user_context) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST publish_late_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - Task *task = example_task(1, 1, TaskStatus::WAITING); - RetryInfo retry = { - .num_retries = 0, - .timeout = 0, - .fail_callback = publish_late_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - NULL, NULL, NULL); - task_table_add_task(db, task, &retry, publish_late_done_callback, - (void *) publish_late_context); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* First process timer events to make sure the timeout is processed before - * anything else. */ - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(publish_late_failed); - PASS(); -} - -SUITE(task_table_tests) { - RUN_REDIS_TEST(lookup_nil_test); - RUN_REDIS_TEST(add_lookup_test); - // RUN_REDIS_TEST(subscribe_timeout_test); - // RUN_REDIS_TEST(publish_timeout_test); - // RUN_REDIS_TEST(subscribe_retry_test); - // RUN_REDIS_TEST(publish_retry_test); - // RUN_REDIS_TEST(subscribe_late_test); - // RUN_REDIS_TEST(publish_late_test); -} - -GREATEST_MAIN_DEFS(); - -int main(int argc, char **argv) { - g_task_builder = make_task_builder(); - GREATEST_MAIN_BEGIN(); - RUN_SUITE(task_table_tests); - GREATEST_MAIN_END(); -} diff --git a/src/common/test/task_tests.cc b/src/common/test/task_tests.cc deleted file mode 100644 index 2277912e7dec0..0000000000000 --- a/src/common/test/task_tests.cc +++ /dev/null @@ -1,212 +0,0 @@ -#include "greatest.h" - -#include -#include -#include - -#include "common.h" -#include "test_common.h" -#include "task.h" -#include "io.h" - -SUITE(task_tests); - -TEST task_test(void) { - TaskID parent_task_id = TaskID::from_random(); - FunctionID func_id = FunctionID::from_random(); - TaskBuilder *builder = make_task_builder(); - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 2); - - UniqueID arg1 = UniqueID::from_random(); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, (uint8_t *) "hello", 5); - UniqueID arg2 = UniqueID::from_random(); - TaskSpec_args_add_ref(builder, &arg2, 1); - TaskSpec_args_add_val(builder, (uint8_t *) "world", 5); - /* Finish constructing the spec. This constructs the task ID and the - * return IDs. */ - int64_t size; - TaskSpec *spec = TaskSpec_finish_construct(builder, &size); - - /* Check that the spec was constructed as expected. 
*/ - ASSERT(TaskSpec_num_args(spec) == 4); - ASSERT(TaskSpec_num_returns(spec) == 2); - ASSERT(FunctionID_equal(TaskSpec_function(spec), func_id)); - ASSERT(TaskSpec_arg_id(spec, 0, 0) == arg1); - ASSERT(memcmp(TaskSpec_arg_val(spec, 1), (uint8_t *) "hello", - TaskSpec_arg_length(spec, 1)) == 0); - ASSERT(TaskSpec_arg_id(spec, 2, 0) == arg2); - ASSERT(memcmp(TaskSpec_arg_val(spec, 3), (uint8_t *) "world", - TaskSpec_arg_length(spec, 3)) == 0); - - TaskSpec_free(spec); - free_task_builder(builder); - PASS(); -} - -TEST deterministic_ids_test(void) { - TaskBuilder *builder = make_task_builder(); - /* Define the inputs to the task construction. */ - TaskID parent_task_id = TaskID::from_random(); - FunctionID func_id = FunctionID::from_random(); - UniqueID arg1 = UniqueID::from_random(); - uint8_t *arg2 = (uint8_t *) "hello world"; - - /* Construct a first task. */ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size1; - TaskSpec *spec1 = TaskSpec_finish_construct(builder, &size1); - - /* Construct a second identical task. */ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size2; - TaskSpec *spec2 = TaskSpec_finish_construct(builder, &size2); - - /* Check that these tasks have the same task IDs and the same return IDs. */ - ASSERT(TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec2))); - ASSERT(TaskSpec_return(spec1, 0) == TaskSpec_return(spec2, 0)); - ASSERT(TaskSpec_return(spec1, 1) == TaskSpec_return(spec2, 1)); - ASSERT(TaskSpec_return(spec1, 2) == TaskSpec_return(spec2, 2)); - /* Check that the return IDs are all distinct. */ - ASSERT(!(TaskSpec_return(spec1, 0) == TaskSpec_return(spec2, 1))); - ASSERT(!(TaskSpec_return(spec1, 0) == TaskSpec_return(spec2, 2))); - ASSERT(!(TaskSpec_return(spec1, 1) == TaskSpec_return(spec2, 2))); - - /* Create more tasks that are only mildly different. */ - - /* Construct a task with a different parent task ID. */ - TaskSpec_start_construct(builder, DriverID::nil(), TaskID::from_random(), 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size3; - TaskSpec *spec3 = TaskSpec_finish_construct(builder, &size3); - - /* Construct a task with a different parent counter. */ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 1, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size4; - TaskSpec *spec4 = TaskSpec_finish_construct(builder, &size4); - - /* Construct a task with a different function ID. */ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, FunctionID::from_random(), - 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size5; - TaskSpec *spec5 = TaskSpec_finish_construct(builder, &size5); - - /* Construct a task with a different object ID argument. 
*/ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - ObjectID object_id = ObjectID::from_random(); - TaskSpec_args_add_ref(builder, &object_id, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size6; - TaskSpec *spec6 = TaskSpec_finish_construct(builder, &size6); - - /* Construct a task with a different value argument. */ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, (uint8_t *) "hello_world", 11); - int64_t size7; - TaskSpec *spec7 = TaskSpec_finish_construct(builder, &size7); - - /* Check that the task IDs are all distinct from the original. */ - ASSERT(!TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec3))); - ASSERT(!TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec4))); - ASSERT(!TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec5))); - ASSERT(!TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec6))); - ASSERT(!TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec7))); - - /* Check that the return object IDs are distinct from the originals. */ - TaskSpec *specs[6] = {spec1, spec3, spec4, spec5, spec6, spec7}; - for (int task_index1 = 0; task_index1 < 6; ++task_index1) { - for (int return_index1 = 0; return_index1 < 3; ++return_index1) { - for (int task_index2 = 0; task_index2 < 6; ++task_index2) { - for (int return_index2 = 0; return_index2 < 3; ++return_index2) { - if (task_index1 != task_index2 && return_index1 != return_index2) { - ASSERT(!(TaskSpec_return(specs[task_index1], return_index1) == - TaskSpec_return(specs[task_index2], return_index2))); - } - } - } - } - } - - TaskSpec_free(spec1); - TaskSpec_free(spec2); - TaskSpec_free(spec3); - TaskSpec_free(spec4); - TaskSpec_free(spec5); - TaskSpec_free(spec6); - TaskSpec_free(spec7); - free_task_builder(builder); - PASS(); -} - -TEST send_task(void) { - TaskBuilder *builder = make_task_builder(); - TaskID parent_task_id = TaskID::from_random(); - FunctionID func_id = FunctionID::from_random(); - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 2); - ObjectID object_id = ObjectID::from_random(); - TaskSpec_args_add_ref(builder, &object_id, 1); - TaskSpec_args_add_val(builder, (uint8_t *) "Hello", 5); - TaskSpec_args_add_val(builder, (uint8_t *) "World", 5); - object_id = ObjectID::from_random(); - TaskSpec_args_add_ref(builder, &object_id, 1); - int64_t size; - TaskSpec *spec = TaskSpec_finish_construct(builder, &size); - int fd[2]; - socketpair(AF_UNIX, SOCK_STREAM, 0, fd); - write_message(fd[0], static_cast(CommonMessageType::SUBMIT_TASK), - size, (uint8_t *) spec); - int64_t type; - int64_t length; - uint8_t *message; - read_message(fd[1], &type, &length, &message); - TaskSpec *result = (TaskSpec *) message; - ASSERT(static_cast(type) == - CommonMessageType::SUBMIT_TASK); - ASSERT(memcmp(spec, result, size) == 0); - TaskSpec_free(spec); - free(result); - free_task_builder(builder); - PASS(); -} - -SUITE(task_tests) { - RUN_TEST(task_test); - RUN_TEST(deterministic_ids_test); - RUN_TEST(send_task); -} - -GREATEST_MAIN_DEFS(); - -int main(int argc, char **argv) { - GREATEST_MAIN_BEGIN(); - RUN_SUITE(task_tests); - GREATEST_MAIN_END(); -} diff --git 
a/src/common/test/test_common.h b/src/common/test/test_common.h
deleted file mode 100644
index 03984e6f22490..0000000000000
--- a/src/common/test/test_common.h
+++ /dev/null
@@ -1,91 +0,0 @@
-#ifndef TEST_COMMON_H
-#define TEST_COMMON_H
-
-#include
-
-#include
-#include
-#include
-
-#include "common.h"
-#include "io.h"
-#include "hiredis/hiredis.h"
-#include "state/redis.h"
-
-#ifndef _WIN32
-/* This function is actually not declared in standard POSIX, so declare it. */
-extern int usleep(useconds_t usec);
-#endif
-
-/* I/O helper methods to retry binding to sockets. */
-static inline std::string bind_ipc_sock_retry(const char *socket_name_format,
-                                              int *fd) {
-  std::string socket_name;
-  for (int num_retries = 0; num_retries < 5; ++num_retries) {
-    RAY_LOG(INFO) << "trying to find plasma socket (attempt " << num_retries
-                  << ")";
-    size_t size = std::snprintf(nullptr, 0, socket_name_format, rand()) + 1;
-    char socket_name_c_str[size];
-    std::snprintf(socket_name_c_str, size, socket_name_format, rand());
-    socket_name = std::string(socket_name_c_str);
-
-    *fd = bind_ipc_sock(socket_name.c_str(), true);
-    if (*fd < 0) {
-      /* Sleep for 100ms. */
-      usleep(100000);
-      continue;
-    }
-    break;
-  }
-  return socket_name;
-}
-
-static inline int bind_inet_sock_retry(int *fd) {
-  int port = -1;
-  for (int num_retries = 0; num_retries < 5; ++num_retries) {
-    port = 10000 + rand() % 40000;
-    *fd = bind_inet_sock(port, true);
-    if (*fd < 0) {
-      /* Sleep for 100ms. */
-      usleep(100000);
-      continue;
-    }
-    break;
-  }
-  return port;
-}
-
-/* Flush redis. */
-static inline void flushall_redis(void) {
-  /* Flush the primary shard. */
-  redisContext *context = redisConnect("127.0.0.1", 6379);
-  std::vector<std::string> db_shards_addresses;
-  std::vector<int> db_shards_ports;
-  get_redis_shards(context, db_shards_addresses, db_shards_ports);
-  freeReplyObject(redisCommand(context, "FLUSHALL"));
-  /* Readd the shard locations. */
-  freeReplyObject(redisCommand(context, "SET NumRedisShards %d",
-                               db_shards_addresses.size()));
-  for (size_t i = 0; i < db_shards_addresses.size(); ++i) {
-    freeReplyObject(redisCommand(context, "RPUSH RedisShards %s:%d",
-                                 db_shards_addresses[i].c_str(),
-                                 db_shards_ports[i]));
-  }
-  redisFree(context);
-
-  /* Flush the remaining shards. */
-  for (size_t i = 0; i < db_shards_addresses.size(); ++i) {
-    context = redisConnect(db_shards_addresses[i].c_str(), db_shards_ports[i]);
-    freeReplyObject(redisCommand(context, "FLUSHALL"));
-    redisFree(context);
-  }
-}
-
-/* Cleanup method for running tests with the greatest library.
- * Runs the test, then clears the Redis database.
*/ -#define RUN_REDIS_TEST(test) \ - flushall_redis(); \ - RUN_TEST(test); \ - flushall_redis(); - -#endif /* TEST_COMMON */ diff --git a/src/common/thirdparty/download_thirdparty.bat b/src/common/thirdparty/download_thirdparty.bat deleted file mode 100644 index 988592f83af62..0000000000000 --- a/src/common/thirdparty/download_thirdparty.bat +++ /dev/null @@ -1,15 +0,0 @@ -@SetLocal - @Echo Off - @PushD "%~dp0" - git submodule update --init --jobs="%NUMBER_OF_PROCESSORS%" - @If Not Exist "python\.git" git clone "https://github.com/austinsc/python.git" - Call :GitApply "python" "%CD%/patches/windows/python-pyconfig.patch" - Call :GitApply "redis-windows" "%CD%/patches/windows/redis.patch" - @PopD -@EndLocal -@GoTo :EOF - -:GitApply - @REM Check if patch already applied by attempting to apply it in reverse; if not, then force-reapply it - git -C "%~1" apply "%~2" -R --check 2> NUL || git -C "%~1" apply "%~2" --3way 2> NUL || git -C "%~1" reset --hard && git -C "%~1" apply "%~2" --3way -@GoTo :EOF diff --git a/src/common/thirdparty/greatest.h b/src/common/thirdparty/greatest.h deleted file mode 100644 index eb34ff4263ece..0000000000000 --- a/src/common/thirdparty/greatest.h +++ /dev/null @@ -1,1023 +0,0 @@ -/* - * Copyright (c) 2011-2016 Scott Vokes - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#ifndef GREATEST_H -#define GREATEST_H - -#ifdef __cplusplus -extern "C" { -#endif - -/* 1.2.1 */ -#define GREATEST_VERSION_MAJOR 1 -#define GREATEST_VERSION_MINOR 2 -#define GREATEST_VERSION_PATCH 1 - -/* A unit testing system for C, contained in 1 file. - * It doesn't use dynamic allocation or depend on anything - * beyond ANSI C89. - * - * An up-to-date version can be found at: - * https://github.com/silentbicycle/greatest/ - */ - - -/********************************************************************* - * Minimal test runner template - *********************************************************************/ -#if 0 -#include "greatest.h" -TEST foo_should_foo(void) { - PASS(); -} -static void setup_cb(void *data) { - printf("setup callback for each test case\n"); -} -static void teardown_cb(void *data) { - printf("teardown callback for each test case\n"); -} -SUITE(suite) { - /* Optional setup/teardown callbacks which will be run before/after - * every test case. If using a test suite, they will be cleared when - * the suite finishes. */ - SET_SETUP(setup_cb, voidp_to_callback_data); - SET_TEARDOWN(teardown_cb, voidp_to_callback_data); - RUN_TEST(foo_should_foo); -} -/* Add definitions that need to be in the test runner's main file. */ -GREATEST_MAIN_DEFS(); -/* Set up, run suite(s) of tests, report pass/fail/skip stats. */ -int run_tests(void) { - GREATEST_INIT(); /* init. greatest internals */ - /* List of suites to run (if any). 
*/ - RUN_SUITE(suite); - /* Tests can also be run directly, without using test suites. */ - RUN_TEST(foo_should_foo); - GREATEST_PRINT_REPORT(); /* display results */ - return greatest_all_passed(); -} -/* main(), for a standalone command-line test runner. - * This replaces run_tests above, and adds command line option - * handling and exiting with a pass/fail status. */ -int main(int argc, char **argv) { - GREATEST_MAIN_BEGIN(); /* init & parse command-line args */ - RUN_SUITE(suite); - GREATEST_MAIN_END(); /* display results */ -} -#endif -/*********************************************************************/ - - -#include -#include -#include -#include - -/*********** - * Options * - ***********/ - -/* Default column width for non-verbose output. */ -#ifndef GREATEST_DEFAULT_WIDTH -#define GREATEST_DEFAULT_WIDTH 72 -#endif - -/* FILE *, for test logging. */ -#ifndef GREATEST_STDOUT -#define GREATEST_STDOUT stdout -#endif - -/* Remove GREATEST_ prefix from most commonly used symbols? */ -#ifndef GREATEST_USE_ABBREVS -#define GREATEST_USE_ABBREVS 1 -#endif - -/* Set to 0 to disable all use of setjmp/longjmp. */ -#ifndef GREATEST_USE_LONGJMP -#define GREATEST_USE_LONGJMP 1 -#endif - -#if GREATEST_USE_LONGJMP -#include -#endif - -/* Set to 0 to disable all use of time.h / clock(). */ -#ifndef GREATEST_USE_TIME -#define GREATEST_USE_TIME 1 -#endif - -#if GREATEST_USE_TIME -#include -#endif - -/* Floating point type, for ASSERT_IN_RANGE. */ -#ifndef GREATEST_FLOAT -#define GREATEST_FLOAT double -#define GREATEST_FLOAT_FMT "%g" -#endif - -/********* - * Types * - *********/ - -/* Info for the current running suite. */ -typedef struct greatest_suite_info { - unsigned int tests_run; - unsigned int passed; - unsigned int failed; - unsigned int skipped; - -#if GREATEST_USE_TIME - /* timers, pre/post running suite and individual tests */ - clock_t pre_suite; - clock_t post_suite; - clock_t pre_test; - clock_t post_test; -#endif -} greatest_suite_info; - -/* Type for a suite function. */ -typedef void (greatest_suite_cb)(void); - -/* Types for setup/teardown callbacks. If non-NULL, these will be run - * and passed the pointer to their additional data. */ -typedef void (greatest_setup_cb)(void *udata); -typedef void (greatest_teardown_cb)(void *udata); - -/* Type for an equality comparison between two pointers of the same type. - * Should return non-0 if equal, otherwise 0. - * UDATA is a closure value, passed through from ASSERT_EQUAL_T[m]. */ -typedef int greatest_equal_cb(const void *exp, const void *got, void *udata); - -/* Type for a callback that prints a value pointed to by T. - * Return value has the same meaning as printf's. - * UDATA is a closure value, passed through from ASSERT_EQUAL_T[m]. */ -typedef int greatest_printf_cb(const void *t, void *udata); - -/* Callbacks for an arbitrary type; needed for type-specific - * comparisons via GREATEST_ASSERT_EQUAL_T[m].*/ -typedef struct greatest_type_info { - greatest_equal_cb *equal; - greatest_printf_cb *print; -} greatest_type_info; - -typedef struct greatest_memory_cmp_env { - const unsigned char *exp; - const unsigned char *got; - size_t size; -} greatest_memory_cmp_env; - -/* Callbacks for string and raw memory types. */ -extern greatest_type_info greatest_type_info_string; -extern greatest_type_info greatest_type_info_memory; - -typedef enum { - GREATEST_FLAG_FIRST_FAIL = 0x01, - GREATEST_FLAG_LIST_ONLY = 0x02 -} greatest_flag_t; - -/* Struct containing all test runner state. 
*/ -typedef struct greatest_run_info { - unsigned char flags; - unsigned char verbosity; - unsigned int tests_run; /* total test count */ - - /* overall pass/fail/skip counts */ - unsigned int passed; - unsigned int failed; - unsigned int skipped; - unsigned int assertions; - - /* currently running test suite */ - greatest_suite_info suite; - - /* info to print about the most recent failure */ - const char *fail_file; - unsigned int fail_line; - const char *msg; - - /* current setup/teardown hooks and userdata */ - greatest_setup_cb *setup; - void *setup_udata; - greatest_teardown_cb *teardown; - void *teardown_udata; - - /* formatting info for ".....s...F"-style output */ - unsigned int col; - unsigned int width; - - /* only run a specific suite or test */ - const char *suite_filter; - const char *test_filter; - -#if GREATEST_USE_TIME - /* overall timers */ - clock_t begin; - clock_t end; -#endif - -#if GREATEST_USE_LONGJMP - jmp_buf jump_dest; -#endif -} greatest_run_info; - -struct greatest_report_t { - /* overall pass/fail/skip counts */ - unsigned int passed; - unsigned int failed; - unsigned int skipped; - unsigned int assertions; -}; - -/* Global var for the current testing context. - * Initialized by GREATEST_MAIN_DEFS(). */ -extern greatest_run_info greatest_info; - -/* Type for ASSERT_ENUM_EQ's ENUM_STR argument. */ -typedef const char *greatest_enum_str_fun(int value); - -/********************** - * Exported functions * - **********************/ - -/* These are used internally by greatest. */ -void greatest_do_pass(const char *name); -void greatest_do_fail(const char *name); -void greatest_do_skip(const char *name); -int greatest_pre_test(const char *name); -void greatest_post_test(const char *name, int res); -void greatest_usage(const char *name); -int greatest_do_assert_equal_t(const void *exp, const void *got, - greatest_type_info *type_info, void *udata); - -/* These are part of the public greatest API. */ -void GREATEST_SET_SETUP_CB(greatest_setup_cb *cb, void *udata); -void GREATEST_SET_TEARDOWN_CB(greatest_teardown_cb *cb, void *udata); -int greatest_all_passed(void); -void greatest_set_test_filter(const char *name); -void greatest_set_suite_filter(const char *name); -void greatest_get_report(struct greatest_report_t *report); -unsigned int greatest_get_verbosity(void); -void greatest_set_verbosity(unsigned int verbosity); -void greatest_set_flag(greatest_flag_t flag); - - -/******************** -* Language Support * -********************/ - -/* If __VA_ARGS__ (C99) is supported, allow parametric testing -* without needing to manually manage the argument struct. */ -#if __STDC_VERSION__ >= 19901L || _MSC_VER >= 1800 -#define GREATEST_VA_ARGS -#endif - - -/********** - * Macros * - **********/ - -/* Define a suite. */ -#define GREATEST_SUITE(NAME) void NAME(void); void NAME(void) - -/* Declare a suite, provided by another compilation unit. */ -#define GREATEST_SUITE_EXTERN(NAME) void NAME(void) - -/* Start defining a test function. - * The arguments are not included, to allow parametric testing. */ -#define GREATEST_TEST static enum greatest_test_res - -/* PASS/FAIL/SKIP result from a test. Used internally. */ -typedef enum greatest_test_res { - GREATEST_TEST_RES_PASS = 0, - GREATEST_TEST_RES_FAIL = -1, - GREATEST_TEST_RES_SKIP = 1 -} greatest_test_res; - -/* Run a suite. */ -#define GREATEST_RUN_SUITE(S_NAME) greatest_run_suite(S_NAME, #S_NAME) - -/* Run a test in the current suite. 
*/ -#define GREATEST_RUN_TEST(TEST) \ - do { \ - if (greatest_pre_test(#TEST) == 1) { \ - enum greatest_test_res res = GREATEST_SAVE_CONTEXT(); \ - if (res == GREATEST_TEST_RES_PASS) { \ - res = TEST(); \ - } \ - greatest_post_test(#TEST, res); \ - } else if (GREATEST_LIST_ONLY()) { \ - fprintf(GREATEST_STDOUT, " %s\n", #TEST); \ - } \ - } while (0) - -/* Ignore a test, don't warn about it being unused. */ -#define GREATEST_IGNORE_TEST(TEST) (void)TEST - -/* Run a test in the current suite with one void * argument, - * which can be a pointer to a struct with multiple arguments. */ -#define GREATEST_RUN_TEST1(TEST, ENV) \ - do { \ - if (greatest_pre_test(#TEST) == 1) { \ - int res = TEST(ENV); \ - greatest_post_test(#TEST, res); \ - } else if (GREATEST_LIST_ONLY()) { \ - fprintf(GREATEST_STDOUT, " %s\n", #TEST); \ - } \ - } while (0) - -#ifdef GREATEST_VA_ARGS -#define GREATEST_RUN_TESTp(TEST, ...) \ - do { \ - if (greatest_pre_test(#TEST) == 1) { \ - int res = TEST(__VA_ARGS__); \ - greatest_post_test(#TEST, res); \ - } else if (GREATEST_LIST_ONLY()) { \ - fprintf(GREATEST_STDOUT, " %s\n", #TEST); \ - } \ - } while (0) -#endif - - -/* Check if the test runner is in verbose mode. */ -#define GREATEST_IS_VERBOSE() ((greatest_info.verbosity) > 0) -#define GREATEST_LIST_ONLY() \ - (greatest_info.flags & GREATEST_FLAG_LIST_ONLY) -#define GREATEST_FIRST_FAIL() \ - (greatest_info.flags & GREATEST_FLAG_FIRST_FAIL) -#define GREATEST_FAILURE_ABORT() \ - (greatest_info.suite.failed > 0 && GREATEST_FIRST_FAIL()) - -/* Message-less forms of tests defined below. */ -#define GREATEST_PASS() GREATEST_PASSm(NULL) -#define GREATEST_FAIL() GREATEST_FAILm(NULL) -#define GREATEST_SKIP() GREATEST_SKIPm(NULL) -#define GREATEST_ASSERT(COND) \ - GREATEST_ASSERTm(#COND, COND) -#define GREATEST_ASSERT_OR_LONGJMP(COND) \ - GREATEST_ASSERT_OR_LONGJMPm(#COND, COND) -#define GREATEST_ASSERT_FALSE(COND) \ - GREATEST_ASSERT_FALSEm(#COND, COND) -#define GREATEST_ASSERT_EQ(EXP, GOT) \ - GREATEST_ASSERT_EQm(#EXP " != " #GOT, EXP, GOT) -#define GREATEST_ASSERT_EQ_FMT(EXP, GOT, FMT) \ - GREATEST_ASSERT_EQ_FMTm(#EXP " != " #GOT, EXP, GOT, FMT) -#define GREATEST_ASSERT_IN_RANGE(EXP, GOT, TOL) \ - GREATEST_ASSERT_IN_RANGEm(#EXP " != " #GOT " +/- " #TOL, EXP, GOT, TOL) -#define GREATEST_ASSERT_EQUAL_T(EXP, GOT, TYPE_INFO, UDATA) \ - GREATEST_ASSERT_EQUAL_Tm(#EXP " != " #GOT, EXP, GOT, TYPE_INFO, UDATA) -#define GREATEST_ASSERT_STR_EQ(EXP, GOT) \ - GREATEST_ASSERT_STR_EQm(#EXP " != " #GOT, EXP, GOT) -#define GREATEST_ASSERT_STRN_EQ(EXP, GOT, SIZE) \ - GREATEST_ASSERT_STRN_EQm(#EXP " != " #GOT, EXP, GOT, SIZE) -#define GREATEST_ASSERT_MEM_EQ(EXP, GOT, SIZE) \ - GREATEST_ASSERT_MEM_EQm(#EXP " != " #GOT, EXP, GOT, SIZE) -#define GREATEST_ASSERT_ENUM_EQ(EXP, GOT, ENUM_STR) \ - GREATEST_ASSERT_ENUM_EQm(#EXP " != " #GOT, EXP, GOT, ENUM_STR) - -/* The following forms take an additional message argument first, - * to be displayed by the test runner. */ - -/* Fail if a condition is not true, with message. */ -#define GREATEST_ASSERTm(MSG, COND) \ - do { \ - greatest_info.assertions++; \ - if (!(COND)) { GREATEST_FAILm(MSG); } \ - } while (0) - -/* Fail if a condition is not true, longjmping out of test. */ -#define GREATEST_ASSERT_OR_LONGJMPm(MSG, COND) \ - do { \ - greatest_info.assertions++; \ - if (!(COND)) { GREATEST_FAIL_WITH_LONGJMPm(MSG); } \ - } while (0) - -/* Fail if a condition is not false, with message. 
*/ -#define GREATEST_ASSERT_FALSEm(MSG, COND) \ - do { \ - greatest_info.assertions++; \ - if ((COND)) { GREATEST_FAILm(MSG); } \ - } while (0) - -/* Fail if EXP != GOT (equality comparison by ==). */ -#define GREATEST_ASSERT_EQm(MSG, EXP, GOT) \ - do { \ - greatest_info.assertions++; \ - if ((EXP) != (GOT)) { GREATEST_FAILm(MSG); } \ - } while (0) - -/* Fail if EXP != GOT (equality comparison by ==). - * Warning: EXP and GOT will be evaluated more than once on failure. */ -#define GREATEST_ASSERT_EQ_FMTm(MSG, EXP, GOT, FMT) \ - do { \ - const char *greatest_FMT = ( FMT ); \ - greatest_info.assertions++; \ - if ((EXP) != (GOT)) { \ - fprintf(GREATEST_STDOUT, "\nExpected: "); \ - fprintf(GREATEST_STDOUT, greatest_FMT, EXP); \ - fprintf(GREATEST_STDOUT, "\n Got: "); \ - fprintf(GREATEST_STDOUT, greatest_FMT, GOT); \ - fprintf(GREATEST_STDOUT, "\n"); \ - GREATEST_FAILm(MSG); \ - } \ - } while (0) - -/* Fail if EXP is not equal to GOT, printing enum IDs. */ -#define GREATEST_ASSERT_ENUM_EQm(MSG, EXP, GOT, ENUM_STR) \ - do { \ - int greatest_EXP = (int)(EXP); \ - int greatest_GOT = (int)(GOT); \ - greatest_enum_str_fun *greatest_ENUM_STR = ENUM_STR; \ - if (greatest_EXP != greatest_GOT) { \ - fprintf(GREATEST_STDOUT, "\nExpected: %s", \ - greatest_ENUM_STR(greatest_EXP)); \ - fprintf(GREATEST_STDOUT, "\n Got: %s\n", \ - greatest_ENUM_STR(greatest_GOT)); \ - GREATEST_FAILm(MSG); \ - } \ - } while (0) \ - -/* Fail if GOT not in range of EXP +|- TOL. */ -#define GREATEST_ASSERT_IN_RANGEm(MSG, EXP, GOT, TOL) \ - do { \ - GREATEST_FLOAT greatest_EXP = (EXP); \ - GREATEST_FLOAT greatest_GOT = (GOT); \ - GREATEST_FLOAT greatest_TOL = (TOL); \ - greatest_info.assertions++; \ - if ((greatest_EXP > greatest_GOT && \ - greatest_EXP - greatest_GOT > greatest_TOL) || \ - (greatest_EXP < greatest_GOT && \ - greatest_GOT - greatest_EXP > greatest_TOL)) { \ - fprintf(GREATEST_STDOUT, \ - "\nExpected: " GREATEST_FLOAT_FMT \ - " +/- " GREATEST_FLOAT_FMT \ - "\n Got: " GREATEST_FLOAT_FMT \ - "\n", \ - greatest_EXP, greatest_TOL, greatest_GOT); \ - GREATEST_FAILm(MSG); \ - } \ - } while (0) - -/* Fail if EXP is not equal to GOT, according to strcmp. */ -#define GREATEST_ASSERT_STR_EQm(MSG, EXP, GOT) \ - do { \ - GREATEST_ASSERT_EQUAL_Tm(MSG, EXP, GOT, \ - &greatest_type_info_string, NULL); \ - } while (0) \ - -/* Fail if EXP is not equal to GOT, according to strcmp. */ -#define GREATEST_ASSERT_STRN_EQm(MSG, EXP, GOT, SIZE) \ - do { \ - size_t size = SIZE; \ - GREATEST_ASSERT_EQUAL_Tm(MSG, EXP, GOT, \ - &greatest_type_info_string, &size); \ - } while (0) \ - -/* Fail if EXP is not equal to GOT, according to memcmp. */ -#define GREATEST_ASSERT_MEM_EQm(MSG, EXP, GOT, SIZE) \ - do { \ - greatest_memory_cmp_env env; \ - env.exp = (const unsigned char *)EXP; \ - env.got = (const unsigned char *)GOT; \ - env.size = SIZE; \ - GREATEST_ASSERT_EQUAL_Tm(MSG, env.exp, env.got, \ - &greatest_type_info_memory, &env); \ - } while (0) \ - -/* Fail if EXP is not equal to GOT, according to a comparison - * callback in TYPE_INFO. If they are not equal, optionally use a - * print callback in TYPE_INFO to print them. */ -#define GREATEST_ASSERT_EQUAL_Tm(MSG, EXP, GOT, TYPE_INFO, UDATA) \ - do { \ - greatest_type_info *type_info = (TYPE_INFO); \ - greatest_info.assertions++; \ - if (!greatest_do_assert_equal_t(EXP, GOT, \ - type_info, UDATA)) { \ - if (type_info == NULL || type_info->equal == NULL) { \ - GREATEST_FAILm("type_info->equal callback missing!"); \ - } else { \ - GREATEST_FAILm(MSG); \ - } \ - } \ - } while (0) \ - -/* Pass. 
*/ -#define GREATEST_PASSm(MSG) \ - do { \ - greatest_info.msg = MSG; \ - return GREATEST_TEST_RES_PASS; \ - } while (0) - -/* Fail. */ -#define GREATEST_FAILm(MSG) \ - do { \ - greatest_info.fail_file = __FILE__; \ - greatest_info.fail_line = __LINE__; \ - greatest_info.msg = MSG; \ - return GREATEST_TEST_RES_FAIL; \ - } while (0) - -/* Optional GREATEST_FAILm variant that longjmps. */ -#if GREATEST_USE_LONGJMP -#define GREATEST_FAIL_WITH_LONGJMP() GREATEST_FAIL_WITH_LONGJMPm(NULL) -#define GREATEST_FAIL_WITH_LONGJMPm(MSG) \ - do { \ - greatest_info.fail_file = __FILE__; \ - greatest_info.fail_line = __LINE__; \ - greatest_info.msg = MSG; \ - longjmp(greatest_info.jump_dest, GREATEST_TEST_RES_FAIL); \ - } while (0) -#endif - -/* Skip the current test. */ -#define GREATEST_SKIPm(MSG) \ - do { \ - greatest_info.msg = MSG; \ - return GREATEST_TEST_RES_SKIP; \ - } while (0) - -/* Check the result of a subfunction using ASSERT, etc. */ -#define GREATEST_CHECK_CALL(RES) \ - do { \ - enum greatest_test_res greatest_RES = RES; \ - if (greatest_RES != GREATEST_TEST_RES_PASS) { \ - return greatest_RES; \ - } \ - } while (0) \ - -#if GREATEST_USE_TIME -#define GREATEST_SET_TIME(NAME) \ - NAME = clock(); \ - if (NAME == (clock_t) -1) { \ - fprintf(GREATEST_STDOUT, \ - "clock error: %s\n", #NAME); \ - exit(EXIT_FAILURE); \ - } - -#define GREATEST_CLOCK_DIFF(C1, C2) \ - fprintf(GREATEST_STDOUT, " (%lu ticks, %.3f sec)", \ - (long unsigned int) (C2) - (long unsigned int)(C1), \ - (double)((C2) - (C1)) / (1.0 * (double)CLOCKS_PER_SEC)) -#else -#define GREATEST_SET_TIME(UNUSED) -#define GREATEST_CLOCK_DIFF(UNUSED1, UNUSED2) -#endif - -#if GREATEST_USE_LONGJMP -#define GREATEST_SAVE_CONTEXT() \ - /* setjmp returns 0 (GREATEST_TEST_RES_PASS) on first call */ \ - /* so the test runs, then RES_FAIL from FAIL_WITH_LONGJMP. */ \ - ((enum greatest_test_res)(setjmp(greatest_info.jump_dest))) -#else -#define GREATEST_SAVE_CONTEXT() \ - /*a no-op, since setjmp/longjmp aren't being used */ \ - GREATEST_TEST_RES_PASS -#endif - -/* Include several function definitions in the main test file. */ -#define GREATEST_MAIN_DEFS() \ - \ -/* Is FILTER a subset of NAME? 
*/ \ -static int greatest_name_match(const char *name, \ - const char *filter) { \ - size_t offset = 0; \ - size_t filter_len = strlen(filter); \ - while (name[offset] != '\0') { \ - if (name[offset] == filter[0]) { \ - if (0 == strncmp(&name[offset], filter, filter_len)) { \ - return 1; \ - } \ - } \ - offset++; \ - } \ - \ - return 0; \ -} \ - \ -int greatest_pre_test(const char *name) { \ - if (!GREATEST_LIST_ONLY() \ - && (!GREATEST_FIRST_FAIL() || greatest_info.suite.failed == 0) \ - && (greatest_info.test_filter == NULL || \ - greatest_name_match(name, greatest_info.test_filter))) { \ - GREATEST_SET_TIME(greatest_info.suite.pre_test); \ - if (greatest_info.setup) { \ - greatest_info.setup(greatest_info.setup_udata); \ - } \ - return 1; /* test should be run */ \ - } else { \ - return 0; /* skipped */ \ - } \ -} \ - \ -void greatest_post_test(const char *name, int res) { \ - GREATEST_SET_TIME(greatest_info.suite.post_test); \ - if (greatest_info.teardown) { \ - void *udata = greatest_info.teardown_udata; \ - greatest_info.teardown(udata); \ - } \ - \ - if (res <= GREATEST_TEST_RES_FAIL) { \ - greatest_do_fail(name); \ - } else if (res >= GREATEST_TEST_RES_SKIP) { \ - greatest_do_skip(name); \ - } else if (res == GREATEST_TEST_RES_PASS) { \ - greatest_do_pass(name); \ - } \ - greatest_info.suite.tests_run++; \ - greatest_info.col++; \ - if (GREATEST_IS_VERBOSE()) { \ - GREATEST_CLOCK_DIFF(greatest_info.suite.pre_test, \ - greatest_info.suite.post_test); \ - fprintf(GREATEST_STDOUT, "\n"); \ - } else if (greatest_info.col % greatest_info.width == 0) { \ - fprintf(GREATEST_STDOUT, "\n"); \ - greatest_info.col = 0; \ - } \ - if (GREATEST_STDOUT == stdout) fflush(stdout); \ -} \ - \ -static void report_suite(void) { \ - if (greatest_info.suite.tests_run > 0) { \ - fprintf(GREATEST_STDOUT, \ - "\n%u test%s - %u passed, %u failed, %u skipped", \ - greatest_info.suite.tests_run, \ - greatest_info.suite.tests_run == 1 ? "" : "s", \ - greatest_info.suite.passed, \ - greatest_info.suite.failed, \ - greatest_info.suite.skipped); \ - GREATEST_CLOCK_DIFF(greatest_info.suite.pre_suite, \ - greatest_info.suite.post_suite); \ - fprintf(GREATEST_STDOUT, "\n"); \ - } \ -} \ - \ -static void update_counts_and_reset_suite(void) { \ - greatest_info.setup = NULL; \ - greatest_info.setup_udata = NULL; \ - greatest_info.teardown = NULL; \ - greatest_info.teardown_udata = NULL; \ - greatest_info.passed += greatest_info.suite.passed; \ - greatest_info.failed += greatest_info.suite.failed; \ - greatest_info.skipped += greatest_info.suite.skipped; \ - greatest_info.tests_run += greatest_info.suite.tests_run; \ - memset(&greatest_info.suite, 0, sizeof(greatest_info.suite)); \ - greatest_info.col = 0; \ -} \ - \ -static void greatest_run_suite(greatest_suite_cb *suite_cb, \ - const char *suite_name) { \ - if (greatest_info.suite_filter && \ - !greatest_name_match(suite_name, greatest_info.suite_filter)) { \ - return; \ - } \ - update_counts_and_reset_suite(); \ - if (GREATEST_FIRST_FAIL() && greatest_info.failed > 0) { return; } \ - fprintf(GREATEST_STDOUT, "\n* Suite %s:\n", suite_name); \ - GREATEST_SET_TIME(greatest_info.suite.pre_suite); \ - suite_cb(); \ - GREATEST_SET_TIME(greatest_info.suite.post_suite); \ - report_suite(); \ -} \ - \ -void greatest_do_pass(const char *name) { \ - if (GREATEST_IS_VERBOSE()) { \ - fprintf(GREATEST_STDOUT, "PASS %s: %s", \ - name, greatest_info.msg ? 
greatest_info.msg : ""); \ - } else { \ - fprintf(GREATEST_STDOUT, "."); \ - } \ - greatest_info.suite.passed++; \ -} \ - \ -void greatest_do_fail(const char *name) { \ - if (GREATEST_IS_VERBOSE()) { \ - fprintf(GREATEST_STDOUT, \ - "FAIL %s: %s (%s:%u)", \ - name, greatest_info.msg ? greatest_info.msg : "", \ - greatest_info.fail_file, greatest_info.fail_line); \ - } else { \ - fprintf(GREATEST_STDOUT, "F"); \ - greatest_info.col++; \ - /* add linebreak if in line of '.'s */ \ - if (greatest_info.col != 0) { \ - fprintf(GREATEST_STDOUT, "\n"); \ - greatest_info.col = 0; \ - } \ - fprintf(GREATEST_STDOUT, "FAIL %s: %s (%s:%u)\n", \ - name, \ - greatest_info.msg ? greatest_info.msg : "", \ - greatest_info.fail_file, greatest_info.fail_line); \ - } \ - greatest_info.suite.failed++; \ -} \ - \ -void greatest_do_skip(const char *name) { \ - if (GREATEST_IS_VERBOSE()) { \ - fprintf(GREATEST_STDOUT, "SKIP %s: %s", \ - name, \ - greatest_info.msg ? \ - greatest_info.msg : "" ); \ - } else { \ - fprintf(GREATEST_STDOUT, "s"); \ - } \ - greatest_info.suite.skipped++; \ -} \ - \ -int greatest_do_assert_equal_t(const void *exp, const void *got, \ - greatest_type_info *type_info, void *udata) { \ - int eq = 0; \ - if (type_info == NULL || type_info->equal == NULL) { \ - return 0; \ - } \ - eq = type_info->equal(exp, got, udata); \ - if (!eq) { \ - if (type_info->print != NULL) { \ - fprintf(GREATEST_STDOUT, "\nExpected: "); \ - (void)type_info->print(exp, udata); \ - fprintf(GREATEST_STDOUT, "\n Got: "); \ - (void)type_info->print(got, udata); \ - fprintf(GREATEST_STDOUT, "\n"); \ - } else { \ - fprintf(GREATEST_STDOUT, \ - "GREATEST_ASSERT_EQUAL_T failure at %s:%u\n", \ - greatest_info.fail_file, \ - greatest_info.fail_line); \ - } \ - } \ - return eq; \ -} \ - \ -void greatest_usage(const char *name) { \ - fprintf(GREATEST_STDOUT, \ - "Usage: %s [-hlfv] [-s SUITE] [-t TEST]\n" \ - " -h, --help print this Help\n" \ - " -l List suites and their tests, then exit\n" \ - " -f Stop runner after first failure\n" \ - " -v Verbose output\n" \ - " -s SUITE only run suites containing string SUITE\n" \ - " -t TEST only run tests containing string TEST\n", \ - name); \ -} \ - \ -static void greatest_parse_args(int argc, char **argv) { \ - int i = 0; \ - for (i = 1; i < argc; i++) { \ - if (0 == strncmp("-t", argv[i], 2)) { \ - if (argc <= i + 1) { \ - greatest_usage(argv[0]); \ - exit(EXIT_FAILURE); \ - } \ - greatest_info.test_filter = argv[i+1]; \ - i++; \ - } else if (0 == strncmp("-s", argv[i], 2)) { \ - if (argc <= i + 1) { \ - greatest_usage(argv[0]); \ - exit(EXIT_FAILURE); \ - } \ - greatest_info.suite_filter = argv[i+1]; \ - i++; \ - } else if (0 == strncmp("-f", argv[i], 2)) { \ - greatest_info.flags |= GREATEST_FLAG_FIRST_FAIL; \ - } else if (0 == strncmp("-v", argv[i], 2)) { \ - greatest_info.verbosity++; \ - } else if (0 == strncmp("-l", argv[i], 2)) { \ - greatest_info.flags |= GREATEST_FLAG_LIST_ONLY; \ - } else if (0 == strncmp("-h", argv[i], 2) || \ - 0 == strncmp("--help", argv[i], 6)) { \ - greatest_usage(argv[0]); \ - exit(EXIT_SUCCESS); \ - } else if (0 == strncmp("--", argv[i], 2)) { \ - break; \ - } else { \ - fprintf(GREATEST_STDOUT, \ - "Unknown argument '%s'\n", argv[i]); \ - greatest_usage(argv[0]); \ - exit(EXIT_FAILURE); \ - } \ - } \ -} \ - \ -int greatest_all_passed(void) { return (greatest_info.failed == 0); } \ - \ -void greatest_set_test_filter(const char *name) { \ - greatest_info.test_filter = name; \ -} \ - \ -void greatest_set_suite_filter(const char *name) { \ - 
greatest_info.suite_filter = name; \ -} \ - \ -void greatest_get_report(struct greatest_report_t *report) { \ - if (report) { \ - report->passed = greatest_info.passed; \ - report->failed = greatest_info.failed; \ - report->skipped = greatest_info.skipped; \ - report->assertions = greatest_info.assertions; \ - } \ -} \ - \ -unsigned int greatest_get_verbosity(void) { \ - return greatest_info.verbosity; \ -} \ - \ -void greatest_set_verbosity(unsigned int verbosity) { \ - greatest_info.verbosity = (unsigned char)verbosity; \ -} \ - \ -void greatest_set_flag(greatest_flag_t flag) { \ - greatest_info.flags |= flag; \ -} \ - \ -void GREATEST_SET_SETUP_CB(greatest_setup_cb *cb, void *udata) { \ - greatest_info.setup = cb; \ - greatest_info.setup_udata = udata; \ -} \ - \ -void GREATEST_SET_TEARDOWN_CB(greatest_teardown_cb *cb, \ - void *udata) { \ - greatest_info.teardown = cb; \ - greatest_info.teardown_udata = udata; \ -} \ - \ -static int greatest_string_equal_cb(const void *exp, const void *got, \ - void *udata) { \ - size_t *size = (size_t *)udata; \ - return (size != NULL \ - ? (0 == strncmp((const char *)exp, (const char *)got, *size)) \ - : (0 == strcmp((const char *)exp, (const char *)got))); \ -} \ - \ -static int greatest_string_printf_cb(const void *t, void *udata) { \ - (void)udata; /* note: does not check \0 termination. */ \ - return fprintf(GREATEST_STDOUT, "%s", (const char *)t); \ -} \ - \ -greatest_type_info greatest_type_info_string = { \ - greatest_string_equal_cb, \ - greatest_string_printf_cb, \ -}; \ - \ -static int greatest_memory_equal_cb(const void *exp, const void *got, \ - void *udata) { \ - greatest_memory_cmp_env *env = (greatest_memory_cmp_env *)udata; \ - return (0 == memcmp(exp, got, env->size)); \ -} \ - \ -static int greatest_memory_printf_cb(const void *t, void *udata) { \ - greatest_memory_cmp_env *env = (greatest_memory_cmp_env *)udata; \ - unsigned char *buf = (unsigned char *)t, diff_mark = ' '; \ - FILE *out = GREATEST_STDOUT; \ - size_t i, line_i, line_len = 0; \ - int len = 0; /* format hexdump with differences highlighted */ \ - for (i = 0; i < env->size; i+= line_len) { \ - diff_mark = ' '; \ - line_len = env->size - i; \ - if (line_len > 16) { line_len = 16; } \ - for (line_i = i; line_i < i + line_len; line_i++) { \ - if (env->exp[line_i] != env->got[line_i]) diff_mark = 'X'; \ - } \ - len += fprintf(out, "\n%04x %c ", (unsigned int)i, diff_mark); \ - for (line_i = i; line_i < i + line_len; line_i++) { \ - int m = env->exp[line_i] == env->got[line_i]; /* match? */ \ - len += fprintf(out, "%02x%c", buf[line_i], m ? ' ' : '<'); \ - } \ - for (line_i = 0; line_i < 16 - line_len; line_i++) { \ - len += fprintf(out, " "); \ - } \ - fprintf(out, " "); \ - for (line_i = i; line_i < i + line_len; line_i++) { \ - unsigned char c = buf[line_i]; \ - len += fprintf(out, "%c", isprint(c) ? c : '.'); \ - } \ - } \ - len += fprintf(out, "\n"); \ - return len; \ -} \ - \ -greatest_type_info greatest_type_info_memory = { \ - greatest_memory_equal_cb, \ - greatest_memory_printf_cb, \ -}; \ - \ -greatest_run_info greatest_info - -/* Init internals. */ -#define GREATEST_INIT() \ - do { \ - /* Suppress unused function warning if features aren't used */ \ - (void)greatest_run_suite; \ - (void)greatest_parse_args; \ - \ - memset(&greatest_info, 0, sizeof(greatest_info)); \ - greatest_info.width = GREATEST_DEFAULT_WIDTH; \ - GREATEST_SET_TIME(greatest_info.begin); \ - } while (0) \ - -/* Handle command-line arguments, etc. 
*/ -#define GREATEST_MAIN_BEGIN() \ - do { \ - GREATEST_INIT(); \ - greatest_parse_args(argc, argv); \ - } while (0) - -/* Report passes, failures, skipped tests, the number of - * assertions, and the overall run time. */ -#define GREATEST_PRINT_REPORT() \ - do { \ - if (!GREATEST_LIST_ONLY()) { \ - update_counts_and_reset_suite(); \ - GREATEST_SET_TIME(greatest_info.end); \ - fprintf(GREATEST_STDOUT, \ - "\nTotal: %u test%s", \ - greatest_info.tests_run, \ - greatest_info.tests_run == 1 ? "" : "s"); \ - GREATEST_CLOCK_DIFF(greatest_info.begin, \ - greatest_info.end); \ - fprintf(GREATEST_STDOUT, ", %u assertion%s\n", \ - greatest_info.assertions, \ - greatest_info.assertions == 1 ? "" : "s"); \ - fprintf(GREATEST_STDOUT, \ - "Pass: %u, fail: %u, skip: %u.\n", \ - greatest_info.passed, \ - greatest_info.failed, greatest_info.skipped); \ - } \ - } while (0) - -/* Report results, exit with exit status based on results. */ -#define GREATEST_MAIN_END() \ - do { \ - GREATEST_PRINT_REPORT(); \ - return (greatest_all_passed() ? EXIT_SUCCESS : EXIT_FAILURE); \ - } while (0) - -/* Make abbreviations without the GREATEST_ prefix for the - * most commonly used symbols. */ -#if GREATEST_USE_ABBREVS -#define TEST GREATEST_TEST -#define SUITE GREATEST_SUITE -#define SUITE_EXTERN GREATEST_SUITE_EXTERN -#define RUN_TEST GREATEST_RUN_TEST -#define RUN_TEST1 GREATEST_RUN_TEST1 -#define RUN_SUITE GREATEST_RUN_SUITE -#define IGNORE_TEST GREATEST_IGNORE_TEST -#define ASSERT GREATEST_ASSERT -#define ASSERTm GREATEST_ASSERTm -#define ASSERT_FALSE GREATEST_ASSERT_FALSE -#define ASSERT_EQ GREATEST_ASSERT_EQ -#define ASSERT_EQ_FMT GREATEST_ASSERT_EQ_FMT -#define ASSERT_IN_RANGE GREATEST_ASSERT_IN_RANGE -#define ASSERT_EQUAL_T GREATEST_ASSERT_EQUAL_T -#define ASSERT_STR_EQ GREATEST_ASSERT_STR_EQ -#define ASSERT_STRN_EQ GREATEST_ASSERT_STRN_EQ -#define ASSERT_MEM_EQ GREATEST_ASSERT_MEM_EQ -#define ASSERT_ENUM_EQ GREATEST_ASSERT_ENUM_EQ -#define ASSERT_FALSEm GREATEST_ASSERT_FALSEm -#define ASSERT_EQm GREATEST_ASSERT_EQm -#define ASSERT_EQ_FMTm GREATEST_ASSERT_EQ_FMTm -#define ASSERT_IN_RANGEm GREATEST_ASSERT_IN_RANGEm -#define ASSERT_EQUAL_Tm GREATEST_ASSERT_EQUAL_Tm -#define ASSERT_STR_EQm GREATEST_ASSERT_STR_EQm -#define ASSERT_STRN_EQm GREATEST_ASSERT_STRN_EQm -#define ASSERT_MEM_EQm GREATEST_ASSERT_MEM_EQm -#define ASSERT_ENUM_EQm GREATEST_ASSERT_ENUM_EQm -#define PASS GREATEST_PASS -#define FAIL GREATEST_FAIL -#define SKIP GREATEST_SKIP -#define PASSm GREATEST_PASSm -#define FAILm GREATEST_FAILm -#define SKIPm GREATEST_SKIPm -#define SET_SETUP GREATEST_SET_SETUP_CB -#define SET_TEARDOWN GREATEST_SET_TEARDOWN_CB -#define CHECK_CALL GREATEST_CHECK_CALL - -#ifdef GREATEST_VA_ARGS -#define RUN_TESTp GREATEST_RUN_TESTp -#endif - -#if GREATEST_USE_LONGJMP -#define ASSERT_OR_LONGJMP GREATEST_ASSERT_OR_LONGJMP -#define ASSERT_OR_LONGJMPm GREATEST_ASSERT_OR_LONGJMPm -#define FAIL_WITH_LONGJMP GREATEST_FAIL_WITH_LONGJMP -#define FAIL_WITH_LONGJMPm GREATEST_FAIL_WITH_LONGJMPm -#endif - -#endif /* USE_ABBREVS */ - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/src/common/thirdparty/patches/.gitattributes b/src/common/thirdparty/patches/.gitattributes deleted file mode 100644 index 9812ceb1ffd9b..0000000000000 --- a/src/common/thirdparty/patches/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -*.patch text eol=lf diff --git a/src/common/thirdparty/patches/windows/python-pyconfig.patch b/src/common/thirdparty/patches/windows/python-pyconfig.patch deleted file mode 100644 index 4280dee774702..0000000000000 --- 
a/src/common/thirdparty/patches/windows/python-pyconfig.patch +++ /dev/null @@ -1,25 +0,0 @@ -diff --git a/inc/Windows/pyconfig.h b/inc/Windows/pyconfig.h -index 1cfc59b..d4861cb ---- a/inc/Windows/pyconfig.h -+++ b/inc/Windows/pyconfig.h -@@ -1,6 +1,11 @@ - #ifndef Py_CONFIG_H - #define Py_CONFIG_H - -+#ifdef _MSC_VER -+#pragma push_macro("_DEBUG") -+#undef _DEBUG -+#endif -+ - /* pyconfig.h. NOT Generated automatically by configure. - - This is a manually maintained version used for the Watcom, -@@ -756,4 +761,8 @@ Py_NO_ENABLE_SHARED to find out. Also support MS_NO_COREDLL for b/w compat */ - least significant byte first */ - #define DOUBLE_IS_LITTLE_ENDIAN_IEEE754 1 - -+#ifdef _MSC_VER -+#pragma pop_macro("_DEBUG") -+#endif -+ - #endif /* !Py_CONFIG_H */ diff --git a/src/common/thirdparty/patches/windows/redis.patch b/src/common/thirdparty/patches/windows/redis.patch deleted file mode 100644 index 5ed2df5105cf3..0000000000000 --- a/src/common/thirdparty/patches/windows/redis.patch +++ /dev/null @@ -1,772 +0,0 @@ -diff --git a/msvs/RedisServer.vcxproj b/msvs/RedisServer.vcxproj -index 115ce90..68afb44 ---- a/msvs/RedisServer.vcxproj -+++ b/msvs/RedisServer.vcxproj -@@ -24,26 +24,26 @@ - - - -- Application -+ StaticLibrary - true -- v120 -+ v140_xp - false - - -- Application -+ StaticLibrary - true -- v120 -+ v140_xp - false - - -- Application -+ StaticLibrary - false -- v120 -+ v140_xp - - -- Application -+ StaticLibrary - false -- v120 -+ v140_xp - - - -@@ -61,41 +61,23 @@ - - - -- -+ - false - redis-server - false -- -- -- false -- redis-server -- false -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -- -- -- false -- redis-server -- false -- Build -- -- -- false -- redis-server -- false -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\$(MSBuildProjectName)\ - - - -- USE_JEMALLOC;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;_DEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions) -- $(SolutionDir)..\deps\lua\src;$(SolutionDir)..\deps\hiredis;$(SolutionDir)..\deps\jemalloc-win\include -- MultiThreadedDebug -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;_DEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions) -+ $(ProjectDir)..\deps\lua\src;$(ProjectDir)..\deps\hiredis;$(ProjectDir)..\deps\jemalloc-win\include - Level3 - ProgramDatabase - Disabled - 4996;4146 -- true -+ false -+ true - - - true -@@ -109,14 +91,14 @@ - - - -- USE_JEMALLOC;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;_DEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501 -- $(SolutionDir)..\deps\lua\src;$(SolutionDir)..\deps\hiredis;$(SolutionDir)..\deps\jemalloc-win\include -- MultiThreadedDebug -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;_DEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501 -+ $(ProjectDir)..\deps\lua\src;$(ProjectDir)..\deps\hiredis;$(ProjectDir)..\deps\jemalloc-win\include - Level3 - ProgramDatabase - Disabled - 4996;4146 -- true -+ false -+ true - - - true -@@ -130,14 +112,13 @@ - - - -- USE_JEMALLOC;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;NDEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions) -- $(SolutionDir)..\deps\lua\src;$(SolutionDir)..\deps\hiredis;$(SolutionDir)..\deps\jemalloc-win\include -- MultiThreaded -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;NDEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions) -+ 
$(ProjectDir)..\deps\lua\src;$(ProjectDir)..\deps\hiredis;$(ProjectDir)..\deps\jemalloc-win\include - Level3 - ProgramDatabase - 4996;4146 -- true - Full -+ true - - - true -@@ -162,13 +143,12 @@ - - - -- USE_JEMALLOC;_OFF_T_DEFINED;_WIN32;LACKS_STDLIB_H;NDEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501 -- $(SolutionDir)..\deps\lua\src;$(SolutionDir)..\deps\hiredis;$(SolutionDir)..\deps\jemalloc-win\include -- MultiThreaded -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;_WIN32;LACKS_STDLIB_H;NDEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501 -+ $(ProjectDir)..\deps\lua\src;$(ProjectDir)..\deps\hiredis;$(ProjectDir)..\deps\jemalloc-win\include - Level3 - ProgramDatabase - 4996;4146 -- true -+ true - - - true -@@ -271,9 +251,6 @@ - - - -- -- {8b897e33-6428-4254-8335-4911d179bad1} -- - - {8c07f811-c81c-432c-b334-1ae6faecf951} - -diff --git a/msvs/hiredis/hiredis.vcxproj b/msvs/hiredis/hiredis.vcxproj -index 0622958..efaedae ---- a/msvs/hiredis/hiredis.vcxproj -+++ b/msvs/hiredis/hiredis.vcxproj -@@ -28,27 +28,25 @@ - StaticLibrary - true - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - true - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - false -- true - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - false -- true - MultiByte -- v120 -+ v140_xp - - - -@@ -66,30 +64,20 @@ - - - -- -+ - hiredis -- -- -- hiredis -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -- -- -- hiredis -- -- -- hiredis -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -+ $(ProjectDir)..\$(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\$(MSBuildProjectName)\ - - - - NotUsing - Level3 - Disabled -- _OFF_T_DEFINED;WIN32;_LIB;_DEBUG;%(PreprocessorDefinitions) -- MultiThreadedDebug -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;_LIB;_DEBUG;%(PreprocessorDefinitions) - 4996 -+ false -+ true - - - Windows -@@ -101,9 +89,10 @@ - NotUsing - Level3 - Disabled -- _OFF_T_DEFINED;WIN32;_LIB;_DEBUG;%(PreprocessorDefinitions) -- MultiThreadedDebug -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;_LIB;_DEBUG;%(PreprocessorDefinitions) - 4996 -+ false -+ true - - - Windows -@@ -117,10 +106,9 @@ - Full - true - true -- _OFF_T_DEFINED;WIN32;_LIB;%(PreprocessorDefinitions) -- MultiThreaded -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;_LIB;%(PreprocessorDefinitions) - 4996 -- true -+ true - - - Windows -@@ -136,10 +124,9 @@ - Full - true - true -- _OFF_T_DEFINED;WIN32;_LIB;%(PreprocessorDefinitions) -- MultiThreaded -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;_LIB;%(PreprocessorDefinitions) - 4996 -- true -+ true - - - Windows -diff --git a/msvs/lua/lua/lua.vcxproj b/msvs/lua/lua/lua.vcxproj -index b187130..adef07b ---- a/msvs/lua/lua/lua.vcxproj -+++ b/msvs/lua/lua/lua.vcxproj -@@ -30,28 +30,28 @@ - true - false - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - true - false - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - false - false - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - false - false - MultiByte -- v120 -+ v140_xp - - - -@@ -69,25 +69,16 @@ - - - -- -+ - true -- .lib -- -- -- true -- .lib -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ - -- -+ - false -- .lib - -- -- false -+ - .lib -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\$(MSBuildProjectName)\ - - - -@@ -95,8 +86,9 @@ - Disabled - 
_OFF_T_DEFINED;WIN32;_DEBUG;_LIB;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions);LUA_ANSI;ENABLE_CJSON_GLOBAL - NotUsing -- MultiThreadedDebug - 4244;4018 -+ false -+ true - - - true -@@ -110,8 +102,9 @@ - Disabled - _OFF_T_DEFINED;WIN32;_DEBUG;_LIB;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501;LUA_ANSI;ENABLE_CJSON_GLOBAL - NotUsing -- MultiThreadedDebug - 4244;4018 -+ false -+ true - - - true -@@ -124,10 +117,10 @@ - Level3 - _OFF_T_DEFINED;WIN32;NDEBUG;_LIB;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions);LUA_ANSI;ENABLE_CJSON_GLOBAL - NotUsing -- MultiThreaded - 4244;4018 - Full - true -+ true - - - true -@@ -140,8 +133,8 @@ - Level3 - _OFF_T_DEFINED;WIN32;NDEBUG;_LIB;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501;LUA_ANSI;ENABLE_CJSON_GLOBAL - NotUsing -- MultiThreaded - 4244;4018 -+ true - - - true -diff --git a/src/Win32_Interop/Win32_ANSI.c b/src/Win32_Interop/Win32_ANSI.c -index 404b84f..e7c55d2 ---- a/src/Win32_Interop/Win32_ANSI.c -+++ b/src/Win32_Interop/Win32_ANSI.c -@@ -737,7 +737,7 @@ void ANSI_printf(char *format, ...) { - memset(buffer, 0, cBufLen); - - va_start(args, format); -- retVal = vsprintf_s(buffer, cBufLen, format, args); -+ retVal = vsnprintf(buffer, cBufLen - 1, format, args); - va_end(args); - - if (retVal > 0) { -diff --git a/src/Win32_Interop/Win32_EventLog.cpp b/src/Win32_Interop/Win32_EventLog.cpp -index 1856540..3db4ddd ---- a/src/Win32_Interop/Win32_EventLog.cpp -+++ b/src/Win32_Interop/Win32_EventLog.cpp -@@ -30,7 +30,6 @@ using namespace std; - - #include "Win32_EventLog.h" - #include "Win32_SmartHandle.h" --#include "EventLog.h" - - static bool eventLogEnabled = true; - static string eventLogIdentity = "redis"; -@@ -129,17 +128,17 @@ void RedisEventLog::LogMessage(LPCSTR msg, const WORD type) { - DWORD eventID; - switch (type) { - case EVENTLOG_ERROR_TYPE: -- eventID = MSG_ERROR_1; -+ eventID = 0x2; - break; - case EVENTLOG_WARNING_TYPE: -- eventID = MSG_WARNING_1; -+ eventID = 0x1; - break; - case EVENTLOG_INFORMATION_TYPE: -- eventID = MSG_INFO_1; -+ eventID = 0x0; - break; - default: - std::cerr << "Unrecognized type: " << type << "\n"; -- eventID = MSG_INFO_1; -+ eventID = 0x0; - break; - } - -diff --git a/src/Win32_Interop/Win32_FDAPI.cpp b/src/Win32_Interop/Win32_FDAPI.cpp -index 3df9af1..f60e3d4 ---- a/src/Win32_Interop/Win32_FDAPI.cpp -+++ b/src/Win32_Interop/Win32_FDAPI.cpp -@@ -46,11 +46,13 @@ fdapi_access access = NULL; - fdapi_bind bind = NULL; - fdapi_connect connect = NULL; - fdapi_fcntl fcntl = NULL; -+fdapi_ioctl ioctl = NULL; - fdapi_fstat fdapi_fstat64 = NULL; - fdapi_fsync fsync = NULL; - fdapi_ftruncate ftruncate = NULL; - fdapi_freeaddrinfo freeaddrinfo = NULL; - fdapi_getaddrinfo getaddrinfo = NULL; -+fdapi_gethostbyname gethostbyname = NULL; - fdapi_getpeername getpeername = NULL; - fdapi_getsockname getsockname = NULL; - fdapi_getsockopt getsockopt = NULL; -@@ -67,7 +69,9 @@ fdapi_open open = NULL; - fdapi_pipe pipe = NULL; - fdapi_poll poll = NULL; - fdapi_read read = NULL; -+fdapi_recv recv = NULL; - fdapi_select select = NULL; -+fdapi_send send = NULL; - fdapi_setsockopt setsockopt = NULL; - fdapi_socket socket = NULL; - fdapi_write write = NULL; -@@ -622,6 +626,23 @@ int FDAPI_fcntl(int rfd, int cmd, int flags = 0 ) { - return -1; - } - -+int FDAPI_ioctl(int rfd, int cmd, char *buf) { -+ try { -+ SocketInfo* socket_info = RFDMap::getInstance().lookupSocketInfo(rfd); -+ if (socket_info != NULL && socket_info->socket != INVALID_SOCKET) { -+ if 
(f_ioctlsocket(socket_info->socket, cmd, (u_long *)buf) != SOCKET_ERROR) { -+ return 0; -+ } else { -+ errno = f_WSAGetLastError(); -+ return -1; -+ } -+ } -+ } CATCH_AND_REPORT(); -+ -+ errno = EBADF; -+ return -1; -+} -+ - int FDAPI_poll(struct pollfd *fds, nfds_t nfds, int timeout) { - try { - struct pollfd* pollCopy = new struct pollfd[nfds]; -@@ -777,6 +798,42 @@ ssize_t FDAPI_read(int rfd, void *buf, size_t count) { - return -1; - } - -+ssize_t FDAPI_recv(int rfd, void *buf, size_t count, int flags) { -+ try { -+ SOCKET socket = RFDMap::getInstance().lookupSocket(rfd); -+ if (socket != INVALID_SOCKET) { -+ int retval = f_recv(socket, (char*) buf, (unsigned int) count, flags); -+ if (retval == -1) { -+ errno = GetLastError(); -+ if (errno == WSAEWOULDBLOCK) { -+ errno = EAGAIN; -+ } -+ } -+ return retval; -+ } -+ } CATCH_AND_REPORT(); -+ errno = EBADF; -+ return -1; -+} -+ -+ssize_t FDAPI_send(int rfd, const void *buf, size_t count, int flags) { -+ try { -+ SOCKET socket = RFDMap::getInstance().lookupSocket(rfd); -+ if (socket != INVALID_SOCKET) { -+ int retval = f_send(socket, (const char*) buf, (unsigned int) count, flags); -+ if (retval == -1) { -+ errno = GetLastError(); -+ if (errno == WSAEWOULDBLOCK) { -+ errno = EAGAIN; -+ } -+ } -+ return retval; -+ } -+ } CATCH_AND_REPORT(); -+ errno = EBADF; -+ return -1; -+} -+ - ssize_t FDAPI_write(int rfd, const void *buf, size_t count) { - try { - SOCKET socket = RFDMap::getInstance().lookupSocket(rfd); -@@ -1195,12 +1252,14 @@ private: - bind = FDAPI_bind; - connect = FDAPI_connect; - fcntl = FDAPI_fcntl; -+ ioctl = FDAPI_ioctl; - fdapi_fstat64 = (fdapi_fstat) FDAPI_fstat64; - freeaddrinfo = FDAPI_freeaddrinfo; - fsync = FDAPI_fsync; - ftruncate = FDAPI_ftruncate; - getaddrinfo = FDAPI_getaddrinfo; - getsockopt = FDAPI_getsockopt; -+ gethostbyname = FDAPI_gethostbyname; - getpeername = FDAPI_getpeername; - getsockname = FDAPI_getsockname; - htonl = FDAPI_htonl; -@@ -1216,9 +1275,11 @@ private: - pipe = FDAPI_pipe; - poll = FDAPI_poll; - read = FDAPI_read; -+ recv = FDAPI_recv; - select = FDAPI_select; - setsockopt = FDAPI_setsockopt; - socket = FDAPI_socket; -+ send = FDAPI_send; - write = FDAPI_write; - } - -diff --git a/src/Win32_Interop/Win32_FDAPI.h b/src/Win32_Interop/Win32_FDAPI.h -index 8fae9c7..6e09596 ---- a/src/Win32_Interop/Win32_FDAPI.h -+++ b/src/Win32_Interop/Win32_FDAPI.h -@@ -116,9 +116,12 @@ typedef int (*fdapi_open)(const char * _Filename, int _OpenFlag, int flags); - typedef int (*fdapi_accept)(int sockfd, struct sockaddr *addr, socklen_t *addrlen); - typedef int (*fdapi_setsockopt)(int sockfd, int level, int optname,const void *optval, socklen_t optlen); - typedef int (*fdapi_fcntl)(int fd, int cmd, int flags); -+typedef int (*fdapi_ioctl)(int fd, int cmd, char *buf); - typedef int (*fdapi_poll)(struct pollfd *fds, nfds_t nfds, int timeout); - typedef int (*fdapi_getsockopt)(int sockfd, int level, int optname, void *optval, socklen_t *optlen); - typedef int (*fdapi_connect)(int sockfd, const struct sockaddr *addr, size_t addrlen); -+typedef ssize_t (*fdapi_recv)(int fd, void *buf, size_t count, int flags); -+typedef ssize_t (*fdapi_send)(int rfd, void const *buf, size_t count, int flags); - typedef ssize_t (*fdapi_read)(int fd, void *buf, size_t count); - typedef ssize_t (*fdapi_write)(int fd, const void *buf, size_t count); - typedef int (*fdapi_fsync)(int fd); -@@ -128,6 +131,7 @@ typedef int (*fdapi_bind)(int sockfd, const struct sockaddr *addr, socklen_t add - typedef u_short (*fdapi_htons)(u_short hostshort); - 
typedef u_long (*fdapi_htonl)(u_long hostlong); - typedef u_short (*fdapi_ntohs)(u_short netshort); -+typedef struct hostent* (*fdapi_gethostbyname)(const char *name); - typedef int (*fdapi_getpeername)(int sockfd, struct sockaddr *addr, socklen_t * addrlen); - typedef int (*fdapi_getsockname)(int sockfd, struct sockaddr* addrsock, int* addrlen ); - typedef void (*fdapi_freeaddrinfo)(struct addrinfo *ai); -@@ -159,12 +163,14 @@ extern fdapi_access access; - extern fdapi_bind bind; - extern fdapi_connect connect; - extern fdapi_fcntl fcntl; -+extern fdapi_ioctl ioctl; - extern fdapi_fstat fdapi_fstat64; - extern fdapi_freeaddrinfo freeaddrinfo; - extern fdapi_fsync fsync; - extern fdapi_ftruncate ftruncate; - extern fdapi_getaddrinfo getaddrinfo; - extern fdapi_getsockopt getsockopt; -+extern fdapi_gethostbyname gethostbyname; - extern fdapi_getpeername getpeername; - extern fdapi_getsockname getsockname; - extern fdapi_htonl htonl; -@@ -180,7 +186,9 @@ extern fdapi_open open; - extern fdapi_pipe pipe; - extern fdapi_poll poll; - extern fdapi_read read; -+extern fdapi_recv recv; - extern fdapi_select select; -+extern fdapi_send send; - extern fdapi_setsockopt setsockopt; - extern fdapi_socket socket; - extern fdapi_write write; -diff --git a/src/Win32_Interop/Win32_Interop.vcxproj b/src/Win32_Interop/Win32_Interop.vcxproj -index 93fc44b..b75d89b ---- a/src/Win32_Interop/Win32_Interop.vcxproj -+++ b/src/Win32_Interop/Win32_Interop.vcxproj -@@ -74,35 +74,6 @@ - - - -- -- -- Document -- md resources --mc.exe -A -b -c -h . -r resources EventLog.mc --rc.exe -foresources/EventLog.res resources/EventLog.rc --link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll -- -- md resources --mc.exe -A -b -c -h . -r resources EventLog.mc --rc.exe -foresources/EventLog.res resources/EventLog.rc --link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll -- -- EventLog.h -- EventLog.h -- md resources --mc.exe -A -b -c -h . -r resources EventLog.mc --rc.exe -foresources/EventLog.res resources/EventLog.rc --link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll -- -- md resources --mc.exe -A -b -c -h . 
-r resources EventLog.mc --rc.exe -foresources/EventLog.res resources/EventLog.rc --link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll -- -- EventLog.h -- EventLog.h -- -- - - {8C07F811-C81C-432C-B334-1AE6FAECF951} - Win32Proj -@@ -113,27 +84,25 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - - StaticLibrary - true -- v120 -+ v140_xp - Unicode - - - StaticLibrary - true -- v120 -+ v140_xp - Unicode - - - StaticLibrary - false -- v120 -- true -+ v140_xp - Unicode - - - StaticLibrary - false -- v120 -- true -+ v140_xp - Unicode - - -@@ -152,13 +121,9 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - - - -- -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -- -- -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -+ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\$(MSBuildProjectName)\ - - - -@@ -166,9 +131,10 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - - Level3 - Disabled -- USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1 -+ _WIN32_WINNT=0x0502;USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;_NO_CRT_STDIO_INLINE;_CRT_SECURE_NO_DEPRECATE;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1 - $(ProjectDir)..\..\deps\lua\src;$(ProjectDir)..\..\deps\jemalloc-win\include -- MultiThreadedDebug -+ false -+ true - - - Windows -@@ -186,9 +152,10 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - - Level3 - Disabled -- USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1;_WIN32_WINNT=0x0501 -+ _WIN32_WINNT=0x0502;USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;_NO_CRT_STDIO_INLINE;_CRT_SECURE_NO_DEPRECATE;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1;_WIN32_WINNT=0x0501 - $(ProjectDir)..\..\deps\lua\src;$(ProjectDir)..\..\deps\jemalloc-win\include -- MultiThreadedDebug -+ false -+ true - - - Windows -@@ -211,10 +178,9 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - Full - true - true -- USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1 -+ _WIN32_WINNT=0x0502;USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;_NO_CRT_STDIO_INLINE;_CRT_SECURE_NO_DEPRECATE;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1 - $(ProjectDir)..\..\deps\lua\src;$(ProjectDir)..\..\deps\jemalloc-win\include -- MultiThreaded -- true -+ true - - - Windows -@@ -235,9 +201,9 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - MaxSpeed - true - true -- USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1;_WIN32_WINNT=0x0501 -+ _WIN32_WINNT=0x0502;USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;_NO_CRT_STDIO_INLINE;_CRT_SECURE_NO_DEPRECATE;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1;_WIN32_WINNT=0x0501 - $(ProjectDir)..\..\deps\lua\src;$(ProjectDir)..\..\deps\jemalloc-win\include -- MultiThreaded -+ true - - - Windows -diff --git a/src/Win32_Interop/Win32_service.cpp 
b/src/Win32_Interop/Win32_service.cpp -index 488538e..1c33f53 ---- a/src/Win32_Interop/Win32_service.cpp -+++ b/src/Win32_Interop/Win32_service.cpp -@@ -59,7 +59,6 @@ this should preceed the other arguments passed to redis. For instance: - #include - #include - #include --#include - #include - #include "Win32_EventLog.h" - #include -diff --git a/src/ziplist.c b/src/ziplist.c -index 24b0a7c..29d445d ---- a/src/ziplist.c -+++ b/src/ziplist.c -@@ -920,7 +920,7 @@ void ziplistRepr(unsigned char *zl) { - entry = zipEntry(p); - printf( - "{" -- "addr 0x%08lx, " /* TODO" verify 0x%08lx */ -+ "addr %p, " - "index %2d, " - "offset %5ld, " - "rl: %5u, " -@@ -929,9 +929,9 @@ void ziplistRepr(unsigned char *zl) { - "pls: %2u, " - "payload %5u" - "} ", -- (PORT_ULONG)p, -+ (void *)p, - index, -- (PORT_ULONG)(p-zl), -+ (long)(p-zl), - entry.headersize+entry.len, - entry.headersize, - entry.prevrawlen, diff --git a/src/global_scheduler/CMakeLists.txt b/src/global_scheduler/CMakeLists.txt deleted file mode 100644 index fec7ec2810d9c..0000000000000 --- a/src/global_scheduler/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -cmake_minimum_required(VERSION 3.4) - -project(global_scheduler) - -include_directories(${CMAKE_CURRENT_LIST_DIR}) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall") - -add_executable(global_scheduler global_scheduler.cc global_scheduler_algorithm.cc) - -# Make sure ${HIREDIS_LIB} is ready before linking. -add_dependencies(global_scheduler hiredis common) - -target_link_libraries(global_scheduler common ${HIREDIS_LIB} ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY} pthread) diff --git a/src/global_scheduler/global_scheduler.cc b/src/global_scheduler/global_scheduler.cc deleted file mode 100644 index 069ad6865d174..0000000000000 --- a/src/global_scheduler/global_scheduler.cc +++ /dev/null @@ -1,492 +0,0 @@ -#include -#include -#include - -#include "common.h" -#include "event_loop.h" -#include "global_scheduler.h" -#include "global_scheduler_algorithm.h" -#include "net.h" -#include "ray/util/util.h" -#include "state/db_client_table.h" -#include "state/local_scheduler_table.h" -#include "state/object_table.h" -#include "state/table.h" -#include "state/task_table.h" - -/** - * Retry the task assignment. If the local scheduler that the task is assigned - * to is no longer active, do not retry the assignment. - * TODO(rkn): We currently only retry the method if the global scheduler - * publishes a task to a local scheduler before the local scheduler has - * subscribed to the channel. If we enforce that ordering, we can remove this - * retry method. - * - * @param id The task ID. - * @param user_context The global scheduler state. - * @param user_data The Task that failed to be assigned. - * @return Void. - */ -void assign_task_to_local_scheduler_retry(UniqueID id, - void *user_context, - void *user_data) { - GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; - Task *task = (Task *) user_data; - RAY_CHECK(Task_state(task) == TaskStatus::SCHEDULED); - - // If the local scheduler has died since we requested the task assignment, do - // not retry again. - DBClientID local_scheduler_id = Task_local_scheduler(task); - auto it = state->local_schedulers.find(local_scheduler_id); - if (it == state->local_schedulers.end()) { - return; - } - - // The local scheduler is still alive. The failure is most likely due to the - // task assignment getting published before the local scheduler subscribed to - // the channel. Retry the assignment. 
- auto retryInfo = RetryInfo{ - .num_retries = 0, // This value is unused. - .timeout = 0, // This value is unused. - .fail_callback = assign_task_to_local_scheduler_retry, - }; - task_table_update(state->db, Task_copy(task), &retryInfo, NULL, user_context); -} - -/** - * Assign the given task to the local scheduler, update Redis and scheduler data - * structures. - * - * @param state Global scheduler state. - * @param task Task to be assigned to the local scheduler. - * @param local_scheduler_id DB client ID for the local scheduler. - * @return Void. - */ -void assign_task_to_local_scheduler(GlobalSchedulerState *state, - Task *task, - DBClientID local_scheduler_id) { - TaskSpec *spec = Task_task_execution_spec(task)->Spec(); - RAY_LOG(DEBUG) << "assigning task to local_scheduler_id = " - << local_scheduler_id; - Task_set_state(task, TaskStatus::SCHEDULED); - Task_set_local_scheduler(task, local_scheduler_id); - RAY_LOG(DEBUG) << "Issuing a task table update for task = " - << Task_task_id(task); - - auto retryInfo = RetryInfo{ - .num_retries = 0, // This value is unused. - .timeout = 0, // This value is unused. - .fail_callback = assign_task_to_local_scheduler_retry, - }; - task_table_update(state->db, Task_copy(task), &retryInfo, NULL, state); - - /* Update the object table info to reflect the fact that the results of this - * task will be created on the machine that the task was assigned to. This can - * be used to improve locality-aware scheduling. */ - for (int64_t i = 0; i < TaskSpec_num_returns(spec); ++i) { - ObjectID return_id = TaskSpec_return(spec, i); - if (state->scheduler_object_info_table.find(return_id) == - state->scheduler_object_info_table.end()) { - SchedulerObjectInfo &obj_info_entry = - state->scheduler_object_info_table[return_id]; - /* The value -1 indicates that the size of the object is not known yet. */ - obj_info_entry.data_size = -1; - } - RAY_CHECK(state->local_scheduler_plasma_map.count(local_scheduler_id) == 1); - state->scheduler_object_info_table[return_id].object_locations.push_back( - state->local_scheduler_plasma_map[local_scheduler_id]); - } - - /* TODO(rkn): We should probably pass around local_scheduler struct pointers - * instead of db_client_id objects. */ - /* Update the local scheduler info. */ - auto it = state->local_schedulers.find(local_scheduler_id); - RAY_CHECK(it != state->local_schedulers.end()); - - LocalScheduler &local_scheduler = it->second; - local_scheduler.num_tasks_sent += 1; - local_scheduler.num_recent_tasks_sent += 1; - // Resource accounting update for this local scheduler. - for (auto const &resource_pair : TaskSpec_get_required_resources(spec)) { - std::string resource_name = resource_pair.first; - double resource_quantity = resource_pair.second; - // The local scheduler must have this resource because otherwise we wouldn't - // be assigning the task to this local scheduler. - RAY_CHECK(local_scheduler.info.dynamic_resources.count(resource_name) == - 1 || - resource_quantity == 0); - // Subtract task's resource from the cached dynamic resource capacity for - // this local scheduler. This will be overwritten on the next heartbeat. 
- local_scheduler.info.dynamic_resources[resource_name] = - MAX(0, local_scheduler.info.dynamic_resources[resource_name] - - resource_quantity); - } -} - -GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop, - const char *node_ip_address, - const char *redis_primary_addr, - int redis_primary_port) { - GlobalSchedulerState *state = new GlobalSchedulerState(); - state->loop = loop; - state->db = db_connect(std::string(redis_primary_addr), redis_primary_port, - "global_scheduler", node_ip_address, - std::vector()); - db_attach(state->db, loop, false); - state->policy_state = GlobalSchedulerPolicyState_init(); - return state; -} - -void GlobalSchedulerState_free(GlobalSchedulerState *state) { - db_disconnect(state->db); - state->local_schedulers.clear(); - GlobalSchedulerPolicyState_free(state->policy_state); - /* Delete the plasma to local scheduler association map. */ - state->plasma_local_scheduler_map.clear(); - - /* Delete the local scheduler to plasma association map. */ - state->local_scheduler_plasma_map.clear(); - - /* Free the scheduler object info table. */ - state->scheduler_object_info_table.clear(); - /* Free the array of unschedulable tasks. */ - int64_t num_pending_tasks = state->pending_tasks.size(); - if (num_pending_tasks > 0) { - RAY_LOG(WARNING) << "There are " << num_pending_tasks - << " remaining tasks in the pending tasks array."; - } - for (int i = 0; i < num_pending_tasks; ++i) { - Task *pending_task = state->pending_tasks[i]; - Task_free(pending_task); - } - state->pending_tasks.clear(); - - /* Destroy the event loop. */ - destroy_outstanding_callbacks(state->loop); - event_loop_destroy(state->loop); - state->loop = NULL; - - /* Free the global scheduler state. */ - delete state; -} - -/* We need this code so we can clean up when we get a SIGTERM signal. */ - -GlobalSchedulerState *g_state; - -void signal_handler(int signal) { - if (signal == SIGTERM) { - GlobalSchedulerState_free(g_state); - exit(0); - } -} - -/* End of the cleanup code. */ - -void process_task_waiting(Task *waiting_task, void *user_context) { - GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; - RAY_LOG(DEBUG) << "Task waiting callback is called."; - bool successfully_assigned = - handle_task_waiting(state, state->policy_state, waiting_task); - /* If the task was not successfully submitted to a local scheduler, add the - * task to the array of pending tasks. The global scheduler will periodically - * resubmit the tasks in this array. */ - if (!successfully_assigned) { - Task *task_copy = Task_copy(waiting_task); - state->pending_tasks.push_back(task_copy); - } -} - -void add_local_scheduler(GlobalSchedulerState *state, - DBClientID db_client_id, - const char *manager_address) { - /* Add plasma_manager ip:port -> local_scheduler_db_client_id association to - * state. */ - state->plasma_local_scheduler_map[std::string(manager_address)] = - db_client_id; - - /* Add local_scheduler_db_client_id -> plasma_manager ip:port association to - * state. */ - state->local_scheduler_plasma_map[db_client_id] = - std::string(manager_address); - - /* Add new local scheduler to the state. */ - LocalScheduler &local_scheduler = state->local_schedulers[db_client_id]; - local_scheduler.id = db_client_id; - local_scheduler.num_heartbeats_missed = 0; - local_scheduler.num_tasks_sent = 0; - local_scheduler.num_recent_tasks_sent = 0; - local_scheduler.info.task_queue_length = 0; - local_scheduler.info.available_workers = 0; - - /* Allow the scheduling algorithm to process this event. 
*/ - handle_new_local_scheduler(state, state->policy_state, db_client_id); -} - -std::unordered_map::iterator remove_local_scheduler( - GlobalSchedulerState *state, - std::unordered_map::iterator it) { - RAY_CHECK(it != state->local_schedulers.end()); - DBClientID local_scheduler_id = it->first; - it = state->local_schedulers.erase(it); - - /* Remove the local scheduler from the mappings. This code only makes sense if - * there is a one-to-one mapping between local schedulers and plasma managers. - */ - std::string manager_address = - state->local_scheduler_plasma_map[local_scheduler_id]; - state->local_scheduler_plasma_map.erase(local_scheduler_id); - state->plasma_local_scheduler_map.erase(manager_address); - - handle_local_scheduler_removed(state, state->policy_state, - local_scheduler_id); - return it; -} - -/** - * Process a notification about a new DB client connecting to Redis. - * - * @param manager_address An ip:port pair for the plasma manager associated with - * this db client. - * @return Void. - */ -void process_new_db_client(DBClient *db_client, void *user_context) { - GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; - RAY_LOG(DEBUG) << "db client table callback for db client = " - << db_client->id; - if (strncmp(db_client->client_type.c_str(), "local_scheduler", - strlen("local_scheduler")) == 0) { - bool local_scheduler_present = - (state->local_schedulers.find(db_client->id) != - state->local_schedulers.end()); - if (db_client->is_alive) { - /* This is a notification for an insert. We may receive duplicate - * notifications since we read the entire table before processing - * notifications. Filter out local schedulers that we already added. */ - if (!local_scheduler_present) { - add_local_scheduler(state, db_client->id, - db_client->manager_address.c_str()); - } - } else { - if (local_scheduler_present) { - remove_local_scheduler(state, - state->local_schedulers.find(db_client->id)); - } - } - } -} - -/** - * Process notification about the new object information. - * - * @param object_id ID of the object that the notification is about. - * @param data_size The object size. - * @param manager_count The number of locations for this object. - * @param manager_ids The vector of Plasma Manager client IDs. - * @param user_context The user context. - * @return Void. - */ -void object_table_subscribe_callback(ObjectID object_id, - int64_t data_size, - const std::vector &manager_ids, - void *user_context) { - /* Extract global scheduler state from the callback context. */ - GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; - RAY_LOG(DEBUG) << "object table subscribe callback for OBJECT = " - << object_id; - - const std::vector managers = - db_client_table_get_ip_addresses(state->db, manager_ids); - RAY_LOG(DEBUG) << "\tManagers<" << managers.size() << ">:"; - for (size_t i = 0; i < managers.size(); i++) { - RAY_LOG(DEBUG) << "\t\t" << managers[i]; - } - - if (state->scheduler_object_info_table.find(object_id) == - state->scheduler_object_info_table.end()) { - /* Construct a new object info hash table entry. 
*/ - SchedulerObjectInfo &obj_info_entry = - state->scheduler_object_info_table[object_id]; - obj_info_entry.data_size = data_size; - - RAY_LOG(DEBUG) << "New object added to object_info_table with id = " - << object_id; - RAY_LOG(DEBUG) << "\tmanager locations:"; - for (size_t i = 0; i < managers.size(); i++) { - RAY_LOG(DEBUG) << "\t\t" << managers[i]; - } - } - - SchedulerObjectInfo &obj_info_entry = - state->scheduler_object_info_table[object_id]; - - /* In all cases, replace the object location vector on each callback. */ - obj_info_entry.object_locations.clear(); - for (size_t i = 0; i < managers.size(); i++) { - obj_info_entry.object_locations.push_back(managers[i]); - } -} - -void local_scheduler_table_handler(DBClientID client_id, - LocalSchedulerInfo info, - void *user_context) { - /* Extract global scheduler state from the callback context. */ - GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; - ARROW_UNUSED(state); - RAY_LOG(DEBUG) << "Local scheduler heartbeat from db_client_id " << client_id; - RAY_LOG(DEBUG) << "total workers = " << info.total_num_workers - << ", task queue length = " << info.task_queue_length - << ", available workers = " << info.available_workers; - - /* Update the local scheduler info struct. */ - auto it = state->local_schedulers.find(client_id); - if (it != state->local_schedulers.end()) { - if (info.is_dead) { - /* The local scheduler is exiting. Increase the number of heartbeats - * missed to the timeout threshold. This will trigger removal of the - * local scheduler the next time the timeout handler fires. */ - it->second.num_heartbeats_missed = - RayConfig::instance().num_heartbeats_timeout(); - } else { - /* Reset the number of tasks sent since the last heartbeat. */ - LocalScheduler &local_scheduler = it->second; - local_scheduler.num_heartbeats_missed = 0; - local_scheduler.num_recent_tasks_sent = 0; - local_scheduler.info = info; - } - } else { - RAY_LOG(WARNING) << "client_id didn't match any cached local scheduler " - << "entries"; - } -} - -int task_cleanup_handler(event_loop *loop, timer_id id, void *context) { - GlobalSchedulerState *state = (GlobalSchedulerState *) context; - /* Loop over the pending tasks in reverse order and resubmit them. */ - auto it = state->pending_tasks.end(); - while (it != state->pending_tasks.begin()) { - it--; - Task *pending_task = *it; - /* Pretend that the task has been resubmitted. */ - bool successfully_assigned = - handle_task_waiting(state, state->policy_state, pending_task); - if (successfully_assigned) { - /* The task was successfully assigned, so remove it from this list and - * free it. This uses the fact that pending_tasks is a vector and so erase - * returns an iterator to the next element in the vector. */ - it = state->pending_tasks.erase(it); - Task_free(pending_task); - } - } - - return GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS; -} - -int heartbeat_timeout_handler(event_loop *loop, timer_id id, void *context) { - GlobalSchedulerState *state = (GlobalSchedulerState *) context; - /* Check for local schedulers that have missed a number of heartbeats. If any - * local schedulers have died, notify others so that the state can be cleaned - * up. */ - /* TODO(swang): If the local scheduler hasn't actually died, then it should - * clean up its state and exit upon receiving this notification. 
*/ - auto it = state->local_schedulers.begin(); - while (it != state->local_schedulers.end()) { - if (it->second.num_heartbeats_missed >= - RayConfig::instance().num_heartbeats_timeout()) { - RAY_LOG(WARNING) << "Missed too many heartbeats from local scheduler, " - << "marking as dead."; - /* Notify others by updating the global state. */ - db_client_table_remove(state->db, it->second.id, NULL, NULL, NULL); - /* Remove the scheduler from the local state. The call to - * remove_local_scheduler modifies the container in place and returns the - * next iterator. */ - it = remove_local_scheduler(state, it); - } else { - it->second.num_heartbeats_missed += 1; - it++; - } - } - - /* Reset the timer. */ - return RayConfig::instance().heartbeat_timeout_milliseconds(); -} - -void start_server(const char *node_ip_address, - const char *redis_primary_addr, - int redis_primary_port) { - event_loop *loop = event_loop_create(); - g_state = GlobalSchedulerState_init(loop, node_ip_address, redis_primary_addr, - redis_primary_port); - /* TODO(rkn): subscribe to notifications from the object table. */ - /* Subscribe to notifications about new local schedulers. TODO(rkn): this - * needs to also get all of the clients that registered with the database - * before this call to subscribe. */ - db_client_table_subscribe(g_state->db, process_new_db_client, - (void *) g_state, NULL, NULL, NULL); - /* Subscribe to notifications about waiting tasks. If a local scheduler - * submits tasks to the global scheduler before the global scheduler - * successfully subscribes, then the local scheduler that submitted the tasks - * will retry. */ - task_table_subscribe(g_state->db, UniqueID::nil(), TaskStatus::WAITING, - process_task_waiting, (void *) g_state, NULL, NULL, - NULL); - - object_table_subscribe_to_notifications(g_state->db, true, - object_table_subscribe_callback, - g_state, NULL, NULL, NULL); - /* Subscribe to notifications from local schedulers. These notifications serve - * as heartbeats and contain informaion about the load on the local - * schedulers. */ - local_scheduler_table_subscribe(g_state->db, local_scheduler_table_handler, - g_state, NULL); - /* Start a timer that periodically checks if there are queued tasks that can - * be scheduled. Currently this is only used to handle the special case in - * which a task is waiting and no node meets its static resource requirements. - * If a new node joins the cluster that does have enough resources, then this - * timer should notice and schedule the task. */ - event_loop_add_timer(loop, GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS, - task_cleanup_handler, g_state); - event_loop_add_timer(loop, - RayConfig::instance().heartbeat_timeout_milliseconds(), - heartbeat_timeout_handler, g_state); - /* Start the event loop. */ - event_loop_run(loop); -} - -int main(int argc, char *argv[]) { - InitShutdownRAII ray_log_shutdown_raii( - ray::RayLog::StartRayLog, ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, - /*log_dir=*/""); - ray::RayLog::InstallFailureSignalHandler(); - signal(SIGTERM, signal_handler); - /* IP address and port of the primary redis instance. */ - char *redis_primary_addr_port = NULL; - /* The IP address of the node that this global scheduler is running on. 
*/ - char *node_ip_address = NULL; - int c; - while ((c = getopt(argc, argv, "h:r:")) != -1) { - switch (c) { - case 'r': - redis_primary_addr_port = optarg; - break; - case 'h': - node_ip_address = optarg; - break; - default: - RAY_LOG(FATAL) << "unknown option " << c; - } - } - - char redis_primary_addr[16]; - int redis_primary_port = -1; - if (!redis_primary_addr_port || - parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr, - &redis_primary_port) == -1) { - RAY_LOG(FATAL) << "specify the primary redis address like 127.0.0.1:6379 " - << "with the -r switch"; - } - if (!node_ip_address) { - RAY_LOG(FATAL) << "specify the node IP address with the -h switch"; - } - start_server(node_ip_address, redis_primary_addr, redis_primary_port); -} diff --git a/src/global_scheduler/global_scheduler.h b/src/global_scheduler/global_scheduler.h deleted file mode 100644 index e1610c555088c..0000000000000 --- a/src/global_scheduler/global_scheduler.h +++ /dev/null @@ -1,94 +0,0 @@ -#ifndef GLOBAL_SCHEDULER_H -#define GLOBAL_SCHEDULER_H - -#include "task.h" - -#include - -#include "ray/gcs/client.h" -#include "state/db.h" -#include "state/local_scheduler_table.h" - -/* The frequency with which the global scheduler checks if there are any tasks - * that haven't been scheduled yet. */ -#define GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS 100 - -/** Contains all information that is associated with a local scheduler. */ -typedef struct { - /** The ID of the local scheduler in Redis. */ - DBClientID id; - /** The number of heartbeat intervals that have passed since we last heard - * from this local scheduler. */ - int64_t num_heartbeats_missed; - /** The number of tasks sent from the global scheduler to this local - * scheduler. */ - int64_t num_tasks_sent; - /** The number of tasks sent from the global scheduler to this local scheduler - * since the last heartbeat arrived. */ - int64_t num_recent_tasks_sent; - /** The latest information about the local scheduler capacity. This is updated - * every time a new local scheduler heartbeat arrives. */ - LocalSchedulerInfo info; -} LocalScheduler; - -typedef class GlobalSchedulerPolicyState GlobalSchedulerPolicyState; - -/** - * This defines a hash table used to cache information about different objects. - */ -typedef struct { - /** The size in bytes of the object. */ - int64_t data_size; - /** A vector of object locations for this object. */ - std::vector object_locations; -} SchedulerObjectInfo; - -/** - * Global scheduler state structure. - */ -typedef struct { - /** The global scheduler event loop. */ - event_loop *loop; - /** The global state store database. */ - DBHandle *db; - /** A hash table mapping local scheduler ID to the local schedulers that are - * connected to Redis. */ - std::unordered_map local_schedulers; - /** The state managed by the scheduling policy. */ - GlobalSchedulerPolicyState *policy_state; - /** The plasma_manager ip:port -> local_scheduler_db_client_id association. */ - std::unordered_map plasma_local_scheduler_map; - /** The local_scheduler_db_client_id -> plasma_manager ip:port association. */ - std::unordered_map local_scheduler_plasma_map; - /** Objects cached by this global scheduler instance. */ - std::unordered_map scheduler_object_info_table; - /** An array of tasks that haven't been scheduled yet. */ - std::vector pending_tasks; -} GlobalSchedulerState; - -/** - * This is a helper method to look up the local scheduler struct that - * corresponds to a particular local_scheduler_id. 
- * - * @param state The state of the global scheduler. - * @param The local_scheduler_id of the local scheduler. - * @return The corresponding local scheduler struct. If the global scheduler is - * not aware of the local scheduler, then this will be NULL. - */ -LocalScheduler *get_local_scheduler(GlobalSchedulerState *state, - DBClientID local_scheduler_id); - -/** - * Assign the given task to the local scheduler, update Redis and scheduler data - * structures. - * - * @param state Global scheduler state. - * @param task Task to be assigned to the local scheduler. - * @param local_scheduler_id DB client ID for the local scheduler. - * @return Void. - */ -void assign_task_to_local_scheduler(GlobalSchedulerState *state, - Task *task, - DBClientID local_scheduler_id); - -#endif /* GLOBAL_SCHEDULER_H */ diff --git a/src/global_scheduler/global_scheduler_algorithm.cc b/src/global_scheduler/global_scheduler_algorithm.cc deleted file mode 100644 index 7ca1b86be9148..0000000000000 --- a/src/global_scheduler/global_scheduler_algorithm.cc +++ /dev/null @@ -1,257 +0,0 @@ -#include - -#include "task.h" -#include "state/task_table.h" - -#include "global_scheduler_algorithm.h" - -GlobalSchedulerPolicyState *GlobalSchedulerPolicyState_init(void) { - GlobalSchedulerPolicyState *policy_state = new GlobalSchedulerPolicyState(); - return policy_state; -} - -void GlobalSchedulerPolicyState_free(GlobalSchedulerPolicyState *policy_state) { - delete policy_state; -} - -/** - * Checks if the given local scheduler satisfies the task's hard constraints. - * - * @param scheduler Local scheduler. - * @param spec Task specification. - * @return True if all tasks's resource constraints are satisfied. False - * otherwise. - */ -bool constraints_satisfied_hard(const LocalScheduler *scheduler, - const TaskSpec *spec) { - if (scheduler->info.static_resources.count("CPU") == 1 && - scheduler->info.static_resources.at("CPU") == 0) { - // Don't give tasks to local schedulers that have 0 CPUs. This can be an - // issue for actor creation tasks that require 0 CPUs (but the subsequent - // actor methods require some CPUs). - return false; - } - - for (auto const &resource_pair : TaskSpec_get_required_resources(spec)) { - std::string resource_name = resource_pair.first; - double resource_quantity = resource_pair.second; - - // Continue on if the task doesn't actually require this resource. - if (resource_quantity == 0) { - continue; - } - - // Check if the local scheduler has this resource. - if (scheduler->info.static_resources.count(resource_name) == 0) { - return false; - } - - // Check if the local scheduler has enough of the resource. - if (scheduler->info.static_resources.at(resource_name) < - resource_quantity) { - return false; - } - } - return true; -} - -int64_t locally_available_data_size(const GlobalSchedulerState *state, - DBClientID local_scheduler_id, - TaskSpec *task_spec) { - /* This function will compute the total size of all the object dependencies - * for the given task that are already locally available to the specified - * local scheduler. */ - int64_t task_data_size = 0; - - RAY_CHECK(state->local_scheduler_plasma_map.count(local_scheduler_id) == 1); - - const std::string &plasma_manager = - state->local_scheduler_plasma_map.at(local_scheduler_id); - - /* TODO(rkn): Note that if the same object ID appears as multiple arguments, - * then it will be overcounted. 
*/ - for (int64_t i = 0; i < TaskSpec_num_args(task_spec); ++i) { - int count = TaskSpec_arg_id_count(task_spec, i); - for (int j = 0; j < count; ++j) { - ObjectID object_id = TaskSpec_arg_id(task_spec, i, j); - - if (state->scheduler_object_info_table.count(object_id) == 0) { - /* If this global scheduler is not aware of this object ID, then ignore - * it. */ - continue; - } - - const SchedulerObjectInfo &object_size_info = - state->scheduler_object_info_table.at(object_id); - - if (std::find(object_size_info.object_locations.begin(), - object_size_info.object_locations.end(), plasma_manager) == - object_size_info.object_locations.end()) { - /* This local scheduler does not have access to this object, so don't - * count this object. */ - continue; - } - - /* Look at the size of the object. */ - int64_t object_size = object_size_info.data_size; - if (object_size == -1) { - /* This means that this global scheduler does not know the object size - * yet, so assume that the object is one megabyte. TODO(rkn): Maybe we - * should instead use the average object size. */ - object_size = 1000000; - } - - /* If we get here, then this local scheduler has access to this object, so - * count the contribution of this object. */ - task_data_size += object_size; - } - } - - return task_data_size; -} - -double calculate_cost_pending(const GlobalSchedulerState *state, - const LocalScheduler *scheduler, - TaskSpec *task_spec) { - /* Calculate how much data is already present on this machine. TODO(rkn): Note - * that this information is not being used yet. Fix this. */ - locally_available_data_size(state, scheduler->id, task_spec); - /* TODO(rkn): This logic does not load balance properly when the different - * machines have different sizes. Fix this. */ - double cost_pending = scheduler->num_recent_tasks_sent + - scheduler->info.task_queue_length - - scheduler->info.available_workers; - return cost_pending; -} - -bool handle_task_waiting_random(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - Task *task) { - TaskSpec *task_spec = Task_task_execution_spec(task)->Spec(); - RAY_CHECK(task_spec != NULL) - << "task wait handler encounted a task with NULL spec"; - - std::vector feasible_nodes; - - for (const auto &it : state->local_schedulers) { - // Local scheduler map iterator yields pairs. - const LocalScheduler &local_scheduler = it.second; - if (!constraints_satisfied_hard(&local_scheduler, task_spec)) { - continue; - } - // Add this local scheduler as a candidate for random selection. - feasible_nodes.push_back(it.first); - } - - if (feasible_nodes.size() == 0) { - RAY_LOG(ERROR) << "Infeasible task. No nodes satisfy hard constraints for " - << "task = " << Task_task_id(task); - return false; - } - - // Randomly select the local scheduler. TODO(atumanov): replace with - // std::discrete_distribution. - std::uniform_int_distribution<> dis(0, feasible_nodes.size() - 1); - DBClientID local_scheduler_id = - feasible_nodes[dis(policy_state->getRandomGenerator())]; - RAY_CHECK(!local_scheduler_id.is_nil()) - << "Task is feasible, but doesn't have a local scheduler assigned."; - // A local scheduler ID was found, so assign the task. 
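The random policy deleted in this hunk filters local schedulers by their hard resource constraints and then draws one of the feasible candidates uniformly with the policy's std::mt19937_64. A minimal standalone sketch of that idea follows; NodeInfo, SatisfiesHardConstraints, and PickRandomFeasibleNode are illustrative names, not part of the removed code.

#include <cstdint>
#include <iostream>
#include <optional>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-in for the deleted LocalScheduler / TaskSpec structures.
struct NodeInfo {
  std::string id;
  std::unordered_map<std::string, double> static_resources;
};

// Hard-constraint check: every nonzero resource the task requests must be
// present on the node in at least the requested quantity.
bool SatisfiesHardConstraints(
    const NodeInfo &node,
    const std::unordered_map<std::string, double> &required) {
  for (const auto &[name, quantity] : required) {
    if (quantity == 0) {
      continue;
    }
    auto it = node.static_resources.find(name);
    if (it == node.static_resources.end() || it->second < quantity) {
      return false;
    }
  }
  return true;
}

// Collect the feasible nodes and pick one uniformly at random, or return
// nullopt if the task is infeasible everywhere.
std::optional<std::string> PickRandomFeasibleNode(
    const std::vector<NodeInfo> &nodes,
    const std::unordered_map<std::string, double> &required,
    std::mt19937_64 &gen) {
  std::vector<const NodeInfo *> feasible;
  for (const auto &node : nodes) {
    if (SatisfiesHardConstraints(node, required)) {
      feasible.push_back(&node);
    }
  }
  if (feasible.empty()) {
    return std::nullopt;
  }
  std::uniform_int_distribution<size_t> dis(0, feasible.size() - 1);
  return feasible[dis(gen)]->id;
}

int main() {
  std::mt19937_64 gen(42);
  std::vector<NodeInfo> nodes = {{"node-a", {{"CPU", 4}}},
                                 {"node-b", {{"CPU", 8}, {"GPU", 1}}}};
  auto chosen = PickRandomFeasibleNode(nodes, {{"CPU", 1}, {"GPU", 1}}, gen);
  std::cout << (chosen ? *chosen : std::string("infeasible")) << std::endl;
}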
- assign_task_to_local_scheduler(state, task, local_scheduler_id); - return true; -} - -bool handle_task_waiting_cost(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - Task *task) { - TaskSpec *task_spec = Task_task_execution_spec(task)->Spec(); - int64_t curtime = current_time_ms(); - - RAY_CHECK(task_spec != NULL) - << "task wait handler encounted a task with NULL spec"; - - // For tasks already seen by the global scheduler (spillback > 1), - // adjust scheduled task counts for the source local scheduler. - if (task->execution_spec->SpillbackCount() > 1) { - auto it = state->local_schedulers.find(task->local_scheduler_id); - // Task's previous local scheduler must be present and known. - RAY_CHECK(it != state->local_schedulers.end()); - LocalScheduler &src_local_scheduler = it->second; - src_local_scheduler.num_recent_tasks_sent -= 1; - } - - bool task_feasible = false; - - // Go through all the nodes, calculate the score for each, pick max score. - double best_local_scheduler_score = INT32_MIN; - RAY_CHECK(best_local_scheduler_score < 0) - << "We might have a floating point underflow"; - RAY_LOG(INFO) << "ct[" << curtime << "] task from " - << task->local_scheduler_id << " spillback " - << task->execution_spec->SpillbackCount(); - - // The best node to send this task. - DBClientID best_local_scheduler_id = DBClientID::nil(); - - for (auto it = state->local_schedulers.begin(); - it != state->local_schedulers.end(); it++) { - // For each local scheduler, calculate its score. Check hard constraints - // first. - LocalScheduler *scheduler = &(it->second); - if (!constraints_satisfied_hard(scheduler, task_spec)) { - continue; - } - // Skip the local scheduler the task came from. - if (task->local_scheduler_id == scheduler->id) { - continue; - } - task_feasible = true; - // This node satisfies the hard capacity constraint. Calculate its score. - double score = -1 * calculate_cost_pending(state, scheduler, task_spec); - RAY_LOG(INFO) << "ct[" << curtime << "][" << scheduler->id << "][q" - << scheduler->info.task_queue_length << "][w" - << scheduler->info.available_workers << "]: score " << score - << " bestscore " << best_local_scheduler_score; - if (score >= best_local_scheduler_score) { - best_local_scheduler_score = score; - best_local_scheduler_id = scheduler->id; - } - } - - if (!task_feasible) { - RAY_LOG(ERROR) << "Infeasible task. No nodes satisfy hard constraints for " - << "task = " << Task_task_id(task); - // TODO(atumanov): propagate this error to the task's driver and/or - // cache the task in case new local schedulers satisfy it in the future. - return false; - } - RAY_CHECK(!best_local_scheduler_id.is_nil()) - << "Task is feasible, but doesn't have a local scheduler assigned."; - // A local scheduler ID was found, so assign the task. - assign_task_to_local_scheduler(state, task, best_local_scheduler_id); - return true; -} - -bool handle_task_waiting(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - Task *task) { - return handle_task_waiting_random(state, policy_state, task); -} - -void handle_object_available(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - ObjectID object_id) { - /* Do nothing for now. */ -} - -void handle_new_local_scheduler(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - DBClientID db_client_id) { - /* Do nothing for now. 
*/ -} - -void handle_local_scheduler_removed(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - DBClientID db_client_id) { - /* Do nothing for now. */ -} diff --git a/src/global_scheduler/global_scheduler_algorithm.h b/src/global_scheduler/global_scheduler_algorithm.h deleted file mode 100644 index 69be67d97477d..0000000000000 --- a/src/global_scheduler/global_scheduler_algorithm.h +++ /dev/null @@ -1,126 +0,0 @@ -#ifndef GLOBAL_SCHEDULER_ALGORITHM_H -#define GLOBAL_SCHEDULER_ALGORITHM_H - -#include -#include - -#include "common.h" -#include "global_scheduler.h" -#include "task.h" - -/* ==== The scheduling algorithm ==== - * - * This file contains declaration for all functions and data structures that - * need to be provided if you want to implement a new algorithm for the global - * scheduler. - * - */ - -enum class GlobalSchedulerAlgorithm { - SCHED_ALGORITHM_ROUND_ROBIN = 1, - SCHED_ALGORITHM_TRANSFER_AWARE = 2, - SCHED_ALGORITHM_MAX -}; - -/// The class encapsulating state managed by the global scheduling policy. -class GlobalSchedulerPolicyState { - public: - GlobalSchedulerPolicyState(int64_t round_robin_index) - : round_robin_index_(round_robin_index), - gen_(std::chrono::high_resolution_clock::now() - .time_since_epoch() - .count()) {} - - GlobalSchedulerPolicyState() - : round_robin_index_(0), - gen_(std::chrono::high_resolution_clock::now() - .time_since_epoch() - .count()) {} - - /// Return the policy's random number generator. - /// - /// @return The policy's random number generator. - std::mt19937_64 &getRandomGenerator() { return gen_; } - - /// Return the round robin index maintained by policy state. - /// - /// @return The round robin index. - int64_t getRoundRobinIndex() const { return round_robin_index_; } - - private: - /// The index of the next local scheduler to assign a task to. - int64_t round_robin_index_; - /// Internally maintained random number generator. - std::mt19937_64 gen_; -}; - -/** - * Create the state of the global scheduler policy. This state must be freed by - * the caller. - * - * @return The state of the scheduling policy. - */ -GlobalSchedulerPolicyState *GlobalSchedulerPolicyState_init(void); - -/** - * Free the global scheduler policy state. - * - * @param policy_state The policy state to free. - * @return Void. - */ -void GlobalSchedulerPolicyState_free(GlobalSchedulerPolicyState *policy_state); - -/** - * Main new task handling function in the global scheduler. - * - * @param state Global scheduler state. - * @param policy_state State specific to the scheduling policy. - * @param task New task to be scheduled. - * @return True if the task was assigned to a local scheduler and false - * otherwise. - */ -bool handle_task_waiting(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - Task *task); - -/** - * Handle the fact that a new object is available. - * - * @param state The global scheduler state. - * @param policy_state The state managed by the scheduling policy. - * @param object_id The ID of the object that is now available. - * @return Void. - */ -void handle_object_available(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - ObjectID object_id); - -/** - * Handle a heartbeat message from a local scheduler. TODO(rkn): this is a - * placeholder for now. - * - * @param state The global scheduler state. - * @param policy_state The state managed by the scheduling policy. - * @return Void. 
- */ -void handle_local_scheduler_heartbeat(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state); - -/** - * Handle the presence of a new local scheduler. Currently, this just adds the - * local scheduler to a queue of local schedulers. - * - * @param state The global scheduler state. - * @param policy_state The state managed by the scheduling policy. - * @param The db client ID of the new local scheduler. - * @return Void. - */ -void handle_new_local_scheduler(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - DBClientID db_client_id); - -void handle_local_scheduler_removed(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - DBClientID db_client_id); - -#endif /* GLOBAL_SCHEDULER_ALGORITHM_H */ diff --git a/src/local_scheduler/CMakeLists.txt b/src/local_scheduler/CMakeLists.txt deleted file mode 100644 index 7033c4f2306cf..0000000000000 --- a/src/local_scheduler/CMakeLists.txt +++ /dev/null @@ -1,104 +0,0 @@ -cmake_minimum_required(VERSION 3.4) - -project(local_scheduler) - -add_definitions(-fPIC) - -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - include_directories("${PYTHON_INCLUDE_DIRS}") - include_directories("${NUMPY_INCLUDE_DIR}") -endif() - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall") - -if(UNIX AND NOT APPLE) - link_libraries(rt) -endif() - -include_directories("${CMAKE_CURRENT_LIST_DIR}/") -include_directories("${CMAKE_CURRENT_LIST_DIR}/../") -# TODO(pcm): get rid of this: -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - include_directories("${CMAKE_CURRENT_LIST_DIR}/../plasma/") -endif() - -include_directories("${ARROW_INCLUDE_DIR}") -include_directories("${CMAKE_CURRENT_LIST_DIR}/../common/format/") - -# Compile flatbuffers - -set(LOCAL_SCHEDULER_FBS_SRC "${CMAKE_CURRENT_LIST_DIR}/format/local_scheduler.fbs") -set(OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/format/) - -set(LOCAL_SCHEDULER_FBS_OUTPUT_FILES - "${OUTPUT_DIR}/local_scheduler_generated.h") - -add_custom_command( - OUTPUT ${LOCAL_SCHEDULER_FBS_OUTPUT_FILES} - COMMAND ${FLATBUFFERS_COMPILER} -c -o ${OUTPUT_DIR} ${LOCAL_SCHEDULER_FBS_SRC} --gen-object-api --scoped-enums - DEPENDS ${FBS_DEPENDS} - COMMENT "Running flatc compiler on ${LOCAL_SCHEDULER_FBS_SRC}" - VERBATIM) - -add_custom_target(gen_local_scheduler_fbs DEPENDS ${LOCAL_SCHEDULER_FBS_OUTPUT_FILES}) - -add_dependencies(gen_local_scheduler_fbs arrow) - -add_library(local_scheduler_client STATIC local_scheduler_client.cc) - -# local_scheduler_shared.h includes ray/gcs/client.h which requires gen_gcs_fbs & gen_node_manager_fbs. 
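For the SCHED_ALGORITHM_TRANSFER_AWARE path declared above, the deleted handle_task_waiting_cost scores each feasible node by negating (recently sent tasks + queued tasks - available workers) and keeps the maximum. A compact sketch of that scoring loop under simplified assumptions; NodeLoad, Score, and PickBestNode are illustrative names.

#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <vector>

// Per-node load snapshot, mirroring the heartbeat fields the deleted policy
// reads from each local scheduler (the field names here are illustrative).
struct NodeLoad {
  std::string id;
  int64_t recent_tasks_sent;  // tasks pushed to the node since its last heartbeat
  int64_t task_queue_length;  // tasks already queued on the node
  int64_t available_workers;  // idle workers reported by the node
};

// The deleted policy computes a pending-work cost and negates it, so the node
// with the least backlog (and most idle workers) gets the highest score.
double Score(const NodeLoad &n) {
  double cost_pending = static_cast<double>(
      n.recent_tasks_sent + n.task_queue_length - n.available_workers);
  return -cost_pending;
}

// Return the id of the best-scoring node, or nullopt if there are none.
std::optional<std::string> PickBestNode(const std::vector<NodeLoad> &nodes) {
  double best_score = std::numeric_limits<double>::lowest();
  std::optional<std::string> best_id;
  for (const auto &n : nodes) {
    double score = Score(n);
    if (score >= best_score) {
      best_score = score;
      best_id = n.id;
    }
  }
  return best_id;
}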
-add_dependencies(local_scheduler_client common hiredis gen_local_scheduler_fbs ${COMMON_FBS_OUTPUT_FILES} gen_gcs_fbs gen_node_manager_fbs) - -add_executable(local_scheduler local_scheduler.cc local_scheduler_algorithm.cc) -add_dependencies(local_scheduler hiredis) -target_link_libraries(local_scheduler local_scheduler_client common ${HIREDIS_LIB} ${PLASMA_STATIC_LIB} ray_static ${ARROW_STATIC_LIB} -lpthread ${Boost_SYSTEM_LIBRARY}) - -add_executable(local_scheduler_tests test/local_scheduler_tests.cc local_scheduler.cc local_scheduler_algorithm.cc) -add_dependencies(local_scheduler_tests hiredis) -target_link_libraries(local_scheduler_tests local_scheduler_client common ${HIREDIS_LIB} ${PLASMA_STATIC_LIB} ray_static ${ARROW_STATIC_LIB} -lpthread ${Boost_SYSTEM_LIBRARY}) -target_compile_options(local_scheduler_tests PUBLIC "-DLOCAL_SCHEDULER_TEST") - -macro(get_local_scheduler_library LANG VAR) - set(${VAR} "local_scheduler_library_${LANG}") -endmacro() - -macro(set_local_scheduler_library LANG) - get_local_scheduler_library(${LANG} LOCAL_SCHEDULER_LIBRARY_${LANG}) - set(LOCAL_SCHEDULER_LIBRARY_LANG ${LOCAL_SCHEDULER_LIBRARY_${LANG}}) - include_directories("${CMAKE_CURRENT_LIST_DIR}/../common/lib/${LANG}/") - - file(GLOB LOCAL_SCHEDULER_LIBRARY_${LANG}_SRC - lib/${LANG}/*.cc - ${CMAKE_CURRENT_LIST_DIR}/../common/lib/${LANG}/*.cc) - add_library(${LOCAL_SCHEDULER_LIBRARY_LANG} SHARED - ${LOCAL_SCHEDULER_LIBRARY_${LANG}_SRC}) - - if(APPLE) - if ("${LANG}" STREQUAL "python") - SET_TARGET_PROPERTIES(${LOCAL_SCHEDULER_LIBRARY_LANG} PROPERTIES SUFFIX .so) - endif() - target_link_libraries(${LOCAL_SCHEDULER_LIBRARY_LANG} "-undefined dynamic_lookup" local_scheduler_client common ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY}) - else(APPLE) - target_link_libraries(${LOCAL_SCHEDULER_LIBRARY_LANG} local_scheduler_client common ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY}) - endif(APPLE) - - add_dependencies(${LOCAL_SCHEDULER_LIBRARY_LANG} gen_local_scheduler_fbs) - - install(TARGETS ${LOCAL_SCHEDULER_LIBRARY_LANG} DESTINATION ${CMAKE_SOURCE_DIR}/local_scheduler) -endmacro() - -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - set_local_scheduler_library("python") -endif() - -if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") - add_compile_options("-I$ENV{JAVA_HOME}/include/") - if(WIN32) - add_compile_options("-I$ENV{JAVA_HOME}/include/win32") - elseif(APPLE) - add_compile_options("-I$ENV{JAVA_HOME}/include/darwin") - else() # linux - add_compile_options("-I$ENV{JAVA_HOME}/include/linux") - endif() - set_local_scheduler_library("java") -endif() diff --git a/src/local_scheduler/build/.gitkeep b/src/local_scheduler/build/.gitkeep deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/src/local_scheduler/format/local_scheduler.fbs b/src/local_scheduler/format/local_scheduler.fbs deleted file mode 100644 index ffdf13d6aea41..0000000000000 --- a/src/local_scheduler/format/local_scheduler.fbs +++ /dev/null @@ -1,127 +0,0 @@ -// Local scheduler protocol specification -namespace ray.local_scheduler.protocol; - -enum MessageType:int { - // Task is submitted to the local scheduler. This is sent from a worker to a - // local scheduler. - SubmitTask = 1, - // Notify the local scheduler that a task has finished. This is sent from a - // worker to a local scheduler. - TaskDone, - // Log a message to the event table. This is sent from a worker to a local - // scheduler. 
- EventLogMessage, - // Send an initial connection message to the local scheduler. This is sent - // from a worker or driver to a local scheduler. - RegisterClientRequest, - // Send a reply confirming the successful registration of a worker or driver. - // This is sent from the local scheduler to a worker or driver. - RegisterClientReply, - // Notify the local scheduler that this client is disconnecting gracefully. - // This is sent from a worker to a local scheduler. - DisconnectClient, - // Get a new task from the local scheduler. This is sent from a worker to a - // local scheduler. - GetTask, - // Tell a worker to execute a task. This is sent from a local scheduler to a - // worker. - ExecuteTask, - // Reconstruct or fetch possibly lost objects. This is sent from a worker to - // a local scheduler. - ReconstructObjects, - // For a worker that was blocked on some object(s), tell the local scheduler - // that the worker is now unblocked. This is sent from a worker to a local - // scheduler. - NotifyUnblocked, - // Add a result table entry for an object put. - PutObject, - // A request to get the task frontier for an actor, called by the actor when - // saving a checkpoint. - GetActorFrontierRequest, - // The ActorFrontier response to a GetActorFrontierRequest. The local - // scheduler returns the actor's per-handle task counts and execution - // dependencies, which can later be used as the argument to SetActorFrontier - // when resuming from the checkpoint. - GetActorFrontierReply, - // A request to set the task frontier for an actor, called when resuming from - // a checkpoint. The local scheduler will update the actor's per-handle task - // counts and execution dependencies, discard any tasks that already executed - // before the checkpoint, and make any tasks on the frontier runnable by - // making their execution dependencies available. - SetActorFrontier -} - -table SubmitTaskRequest { - execution_dependencies: [string]; - task_spec: string; -} - -// This message is sent from the local scheduler to a worker. -table GetTaskReply { - // A string of bytes representing the task specification. - task_spec: string; - // The IDs of the GPUs that the worker is allowed to use for this task. - gpu_ids: [int]; -} - -table EventLogMessage { - key: string; - value: string; - timestamp: double; -} - -// This struct is used to register a new worker with the local scheduler. -// It is shipped as part of local_scheduler_connect. -table RegisterClientRequest { - // True if the client is a worker and false if the client is a driver. - is_worker: bool; - // The ID of the worker or driver. - client_id: string; - // The process ID of this worker. - worker_pid: long; - // The driver ID. This is non-nil if the client is a driver. - driver_id: string; -} - -table DisconnectClient { -} - -table ReconstructObjects { - // List of object IDs of the objects that we want to reconstruct or fetch. - object_ids: [string]; - // Do we only want to fetch the objects or also reconstruct them? - fetch_only: bool; -} - -table PutObject { - // Task ID of the task that performed the put. - task_id: string; - // Object ID of the object that is being put. - object_id: string; -} - -// The ActorFrontier is used to represent the current frontier of tasks that -// the local scheduler has marked as runnable for a particular actor. 
It is -// used to save the point in an actor's lifetime at which a checkpoint was -// taken, so that the same frontier of tasks can be made runnable again if the -// actor is resumed from that checkpoint. -table ActorFrontier { - // Actor ID of the actor whose frontier is described. - actor_id: string; - // A list of handle IDs, representing the callers of the actor that have - // submitted a runnable task to the local scheduler. A nil ID represents the - // creator of the actor. - handle_ids: [string]; - // A list representing the number of tasks executed so far, per handle. Each - // count in task_counters corresponds to the handle at the same in index in - // handle_ids. - task_counters: [long]; - // A list representing the execution dependency for the next runnable task, - // per handle. Each execution dependency in frontier_dependencies corresponds - // to the handle at the same in index in handle_ids. - frontier_dependencies: [string]; -} - -table GetActorFrontierRequest { - actor_id: string; -} diff --git a/src/local_scheduler/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/local_scheduler/lib/java/org_ray_runtime_raylet_RayletClientImpl.h deleted file mode 100644 index b730b00643d54..0000000000000 --- a/src/local_scheduler/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ /dev/null @@ -1,134 +0,0 @@ -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include -/* Header for class org_ray_runtime_raylet_RayletClientImpl */ - -#ifndef _Included_org_ray_runtime_raylet_RayletClientImpl -#define _Included_org_ray_runtime_raylet_RayletClientImpl -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeInit - * Signature: (Ljava/lang/String;[BZ[B)J - */ -JNIEXPORT jlong JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(JNIEnv *, - jclass, - jstring, - jbyteArray, - jboolean, - jbyteArray); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeSubmitTask - * Signature: (J[BLjava/nio/ByteBuffer;II)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask(JNIEnv *, - jclass, - jlong, - jbyteArray, - jobject, - jint, - jint); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGetTask - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask(JNIEnv *, - jclass, - jlong); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeDestroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(JNIEnv *, - jclass, - jlong); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeReconstructObjects - * Signature: (J[[BZ)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeReconstructObjects( - JNIEnv *, - jclass, - jlong, - jobjectArray, - jboolean); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeNotifyUnblocked - * Signature: (J)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked(JNIEnv *, - jclass, - jlong); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativePutObject - * Signature: (J[B[B)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativePutObject(JNIEnv *, - jclass, - jlong, - jbyteArray, - jbyteArray); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeWaitObject - * Signature: (J[[BIIZ)[Z - */ 
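The ActorFrontier table above records, per caller handle, how many tasks have executed and the execution dependency that gates the next runnable task, so a checkpointed actor can be resumed at exactly the same point. A conceptual sketch of that bookkeeping follows; plain std::string keys stand in for the binary handle and object IDs, and Frontier/Advance are illustrative names rather than the generated flatbuffer API.

#include <cstdint>
#include <string>
#include <unordered_map>

// Conceptual model of the frontier described above: for each caller handle,
// how many tasks have run and which execution dependency gates the next one.
struct Frontier {
  std::unordered_map<std::string, int64_t> task_counters;
  std::unordered_map<std::string, std::string> frontier_dependencies;
};

// Record that one more task submitted through `handle` has become runnable,
// producing `execution_dep` as the dependency for that handle's next task.
void Advance(Frontier &frontier, const std::string &handle,
             const std::string &execution_dep) {
  frontier.task_counters[handle] += 1;
  frontier.frontier_dependencies[handle] = execution_dep;
}

// Saving `frontier` next to a checkpoint and restoring it on resume lets the
// scheduler discard tasks that already executed and re-enable the rest.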
-JNIEXPORT jbooleanArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(JNIEnv *, - jclass, - jlong, - jobjectArray, - jint, - jint, - jboolean); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateTaskId - * Signature: ([B[BI)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId(JNIEnv *, - jclass, - jbyteArray, - jbyteArray, - jint); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeFreePlasmaObjects - * Signature: (J[[BZ)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( - JNIEnv *, - jclass, - jlong, - jobjectArray, - jboolean); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc deleted file mode 100644 index 7bef00993ab9b..0000000000000 --- a/src/local_scheduler/local_scheduler.cc +++ /dev/null @@ -1,1555 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "common.h" -#include "common_protocol.h" -#include "event_loop.h" -#include "format/local_scheduler_generated.h" -#include "io.h" -#include "local_scheduler.h" -#include "local_scheduler_algorithm.h" -#include "local_scheduler_shared.h" -#include "logging.h" -#include "net.h" -#include "ray/util/util.h" -#include "state/actor_notification_table.h" -#include "state/db.h" -#include "state/db_client_table.h" -#include "state/driver_table.h" -#include "state/error_table.h" -#include "state/object_table.h" -#include "state/task_table.h" - -using MessageType = ray::local_scheduler::protocol::MessageType; - -/** - * A helper function for printing available and requested resource information. - * - * @param state Local scheduler state. - * @param spec Task specification object. - * @return Void. - */ -void print_resource_info(const LocalSchedulerState *state, - const TaskSpec *spec) { -#if RAY_COMMON_LOG_LEVEL <= RAY_COMMON_DEBUG - // Print information about available and requested resources. - std::cout << "Static Resources: " << std::endl; - for (auto const &resource_pair : state->static_resources) { - std::cout << " " << resource_pair.first << ": " << resource_pair.second - << std::endl; - } - std::cout << "Dynamic Resources: " << std::endl; - for (auto const &resource_pair : state->dynamic_resources) { - std::cout << " " << resource_pair.first << ": " << resource_pair.second - << std::endl; - } - if (spec) { - std::cout << "Task Required Resources: " << std::endl; - for (auto const &resource_pair : TaskSpec_get_required_resources(spec)) { - std::cout << " " << resource_pair.first << ": " << resource_pair.second - << std::endl; - } - } -#endif -} - -int force_kill_worker(event_loop *loop, timer_id id, void *context) { - LocalSchedulerClient *worker = (LocalSchedulerClient *) context; - kill(worker->pid, SIGKILL); - close(worker->sock); - delete worker; - return EVENT_LOOP_TIMER_DONE; -} - -void kill_worker(LocalSchedulerState *state, - LocalSchedulerClient *worker, - bool cleanup, - bool suppress_warning) { - /* Erase the local scheduler's reference to the worker. */ - auto it = std::find(state->workers.begin(), state->workers.end(), worker); - RAY_CHECK(it != state->workers.end()); - state->workers.erase(it); - - /* Make sure that we removed the worker. */ - it = std::find(state->workers.begin(), state->workers.end(), worker); - RAY_CHECK(it == state->workers.end()); - - /* Release any resources held by the worker. 
It's important to do this before - * calling handle_worker_removed and handle_actor_worker_disconnect because - * freeing up resources here will allow the scheduling algorithm to dispatch - * more tasks. */ - release_resources(state, worker, worker->resources_in_use); - - /* Erase the algorithm state's reference to the worker. */ - if (worker->actor_id.is_nil()) { - handle_worker_removed(state, state->algorithm_state, worker); - } else { - /* Let the scheduling algorithm process the absence of this worker. */ - handle_actor_worker_disconnect(state, state->algorithm_state, worker, - cleanup); - } - - /* Remove the client socket from the event loop so that we don't process the - * SIGPIPE when the worker is killed. */ - event_loop_remove_file(state->loop, worker->sock); - - /* If the worker has registered a process ID with us and it's a child - * process, use it to send a kill signal. */ - bool free_worker = true; - if (worker->is_child && worker->pid != 0) { - /* If worker is a driver, we should not enter this condition because - * worker->pid should be 0. */ - if (cleanup) { - /* If we're exiting the local scheduler anyway, it's okay to force kill - * the worker immediately. Wait for the process to exit. */ - kill(worker->pid, SIGKILL); - waitpid(worker->pid, NULL, 0); - close(worker->sock); - } else { - /* If we're just cleaning up a single worker, allow it some time to clean - * up its state before force killing. The client socket will be closed - * and the worker struct will be freed after the timeout. */ - kill(worker->pid, SIGTERM); - event_loop_add_timer( - state->loop, RayConfig::instance().kill_worker_timeout_milliseconds(), - force_kill_worker, (void *) worker); - free_worker = false; - } - RAY_LOG(DEBUG) << "Killed worker with pid " << worker->pid; - } - - /* If this worker is still running a task and we aren't cleaning up, push an - * error message to the driver responsible for the task. */ - if (worker->task_in_progress != NULL && !cleanup && !suppress_warning) { - TaskSpec *spec = Task_task_execution_spec(worker->task_in_progress)->Spec(); - - std::ostringstream error_message; - error_message << "The worker with ID " << worker->client_id << " died or " - << "was killed while executing the task with ID " - << TaskSpec_task_id(spec); - push_error(state->db, TaskSpec_driver_id(spec), ErrorIndex::WORKER_DIED, - error_message.str()); - } - - /* Clean up the task in progress. */ - if (worker->task_in_progress) { - /* Update the task table to reflect that the task failed to complete. */ - if (state->db != NULL) { - Task_set_state(worker->task_in_progress, TaskStatus::LOST); - task_table_update(state->db, worker->task_in_progress, NULL, NULL, NULL); - } else { - Task_free(worker->task_in_progress); - } - } - - RAY_LOG(DEBUG) << "Killed worker with pid " << worker->pid; - if (free_worker) { - /* Clean up the client socket after killing the worker so that the worker - * can't receive the SIGPIPE before exiting. */ - close(worker->sock); - delete worker; - } -} - -void LocalSchedulerState_free(LocalSchedulerState *state) { - /* Reset the SIGTERM handler to default behavior, so we try to clean up the - * local scheduler at most once. If a SIGTERM is caught afterwards, there is - * the possibility of orphan worker processes. */ - signal(SIGTERM, SIG_DFL); - /* Send a null heartbeat that tells the global scheduler that we are dead to - * avoid waiting for the heartbeat timeout. 
*/ - if (state->db != NULL) { - local_scheduler_table_disconnect(state->db); - } - - /* Kill any child processes that didn't register as a worker yet. */ - for (auto const &worker_pid : state->child_pids) { - kill(worker_pid, SIGKILL); - waitpid(worker_pid, NULL, 0); - RAY_LOG(INFO) << "Killed worker pid " << worker_pid - << " which hadn't started yet."; - } - - /* Kill any registered workers. */ - /* TODO(swang): It's possible that the local scheduler will exit before all - * of its task table updates make it to redis. */ - while (state->workers.size() > 0) { - /* Note that kill_worker modifies the container state->workers, so it is - * important to do this loop in a way that does not use invalidated - * iterators. */ - kill_worker(state, state->workers.back(), true, false); - } - - /* Disconnect from plasma. */ - ARROW_CHECK_OK(state->plasma_conn->Disconnect()); - delete state->plasma_conn; - state->plasma_conn = NULL; - - /* Clean up the database connection. NOTE(swang): The global scheduler is - * responsible for deleting our entry from the db_client table, so do not - * delete it here. */ - if (state->db != NULL) { - DBHandle_free(state->db); - } - - /* Free the command for starting new workers. */ - if (state->config.start_worker_command != NULL) { - int i = 0; - const char *arg = state->config.start_worker_command[i]; - while (arg != NULL) { - free((void *) arg); - ++i; - arg = state->config.start_worker_command[i]; - } - free(state->config.start_worker_command); - state->config.start_worker_command = NULL; - } - - /* Free the algorithm state. */ - SchedulingAlgorithmState_free(state->algorithm_state); - state->algorithm_state = NULL; - - event_loop *loop = state->loop; - - /* Free the scheduler state. */ - delete state; - - /* Destroy the event loop. */ - destroy_outstanding_callbacks(loop); - event_loop_destroy(loop); -} - -void start_worker(LocalSchedulerState *state) { - /* We can't start a worker if we don't have the path to the worker script. */ - if (state->config.start_worker_command == NULL) { - RAY_LOG(DEBUG) << "No valid command to start worker provided. Cannot start " - << "worker."; - return; - } - /* Launch the process to create the worker. */ - pid_t pid = fork(); - if (pid != 0) { - state->child_pids.push_back(pid); - RAY_LOG(DEBUG) << "Started worker with pid " << pid; - return; - } - - /* Reset the SIGCHLD handler so that it doesn't influence the worker. */ - signal(SIGCHLD, SIG_DFL); - - std::vector command_vector; - for (int i = 0; state->config.start_worker_command[i] != NULL; i++) { - command_vector.push_back(state->config.start_worker_command[i]); - } - - /* Add a NULL pointer to the end. */ - command_vector.push_back(NULL); - - /* Try to execute the worker command. Exit if we're not successful. */ - execvp(command_vector[0], (char *const *) command_vector.data()); - - LocalSchedulerState_free(state); - RAY_LOG(FATAL) << "Failed to start worker"; -} - -/** - * Parse the command to start a worker. This takes in the command string, - * splits it into tokens on the space characters, and allocates an array of the - * tokens, terminated by a NULL pointer. - * - * @param command The command string to start a worker. - * @return A pointer to an array of strings, the tokens in the command string. - * The last element is a NULL pointer. - */ -const char **parse_command(const char *command) { - /* Count the number of tokens. 
*/ - char *command_copy = strdup(command); - const char *delimiter = " "; - char *token = NULL; - int num_args = 0; - token = strtok(command_copy, delimiter); - while (token != NULL) { - ++num_args; - token = strtok(NULL, delimiter); - } - free(command_copy); - - /* Allocate a NULL-terminated array for the tokens. */ - const char **command_args = - (const char **) malloc((num_args + 1) * sizeof(const char *)); - command_args[num_args] = NULL; - - /* Fill in the token array. */ - command_copy = strdup(command); - token = strtok(command_copy, delimiter); - int i = 0; - while (token != NULL) { - command_args[i] = strdup(token); - ++i; - token = strtok(NULL, delimiter); - } - free(command_copy); - - RAY_CHECK(num_args == i); - return command_args; -} - -LocalSchedulerState *LocalSchedulerState_init( - const char *node_ip_address, - event_loop *loop, - const char *redis_primary_addr, - int redis_primary_port, - const char *local_scheduler_socket_name, - const char *plasma_store_socket_name, - const char *plasma_manager_socket_name, - const char *plasma_manager_address, - bool global_scheduler_exists, - const std::unordered_map &static_resource_conf, - const char *start_worker_command, - int num_workers) { - LocalSchedulerState *state = new LocalSchedulerState(); - /* Set the configuration struct for the local scheduler. */ - if (start_worker_command != NULL) { - state->config.start_worker_command = parse_command(start_worker_command); - } else { - state->config.start_worker_command = NULL; - } - if (start_worker_command == NULL) { - RAY_LOG(WARNING) << "No valid command to start a worker provided, local " - << "scheduler will not start any workers."; - } - state->config.global_scheduler_exists = global_scheduler_exists; - - state->loop = loop; - - /* Connect to Redis if a Redis address is provided. */ - if (redis_primary_addr != NULL) { - /* Construct db_connect_args */ - std::vector db_connect_args; - db_connect_args.push_back("local_scheduler_socket_name"); - db_connect_args.push_back(local_scheduler_socket_name); - for (auto const &resource_pair : static_resource_conf) { - // TODO(rkn): This could cause issues if a resource name collides with - // another field name "manager_address". - db_connect_args.push_back(resource_pair.first); - db_connect_args.push_back(std::to_string(resource_pair.second)); - } - - if (plasma_manager_address != NULL) { - db_connect_args.push_back("manager_address"); - db_connect_args.push_back(plasma_manager_address); - } - - state->db = db_connect(std::string(redis_primary_addr), redis_primary_port, - "local_scheduler", node_ip_address, db_connect_args); - db_attach(state->db, loop, false); - } else { - state->db = NULL; - } - /* Connect to Plasma. This method will retry if Plasma hasn't started yet. */ - state->plasma_conn = new plasma::PlasmaClient(); - if (plasma_manager_socket_name != NULL) { - ARROW_CHECK_OK(state->plasma_conn->Connect( - plasma_store_socket_name, plasma_manager_socket_name, - plasma::kPlasmaDefaultReleaseDelay)); - } else { - ARROW_CHECK_OK(state->plasma_conn->Connect( - plasma_store_socket_name, "", plasma::kPlasmaDefaultReleaseDelay)); - } - /* Subscribe to notifications about sealed objects. */ - int plasma_fd; - ARROW_CHECK_OK(state->plasma_conn->Subscribe(&plasma_fd)); - /* Add the callback that processes the notification to the event loop. */ - event_loop_add_file(loop, plasma_fd, EVENT_LOOP_READ, - process_plasma_notification, state); - /* Add scheduler state. 
*/ - state->algorithm_state = SchedulingAlgorithmState_init(); - - /* Initialize resource vectors. */ - state->static_resources = static_resource_conf; - state->dynamic_resources = static_resource_conf; - /* Initialize available GPUs. */ - if (state->static_resources.count("GPU") == 1) { - for (int i = 0; i < state->static_resources["GPU"]; ++i) { - state->available_gpus.push_back(i); - } - } - /* Print some debug information about resource configuration. */ - print_resource_info(state, NULL); - - /* Start the initial set of workers. */ - for (int i = 0; i < num_workers; ++i) { - start_worker(state); - } - - /* Initialize the time at which the previous heartbeat was sent. */ - state->previous_heartbeat_time = current_time_ms(); - - return state; -} - -/* TODO(atumanov): vectorize resource counts on input. */ -bool check_dynamic_resources( - LocalSchedulerState *state, - const std::unordered_map &resources) { - for (auto const &resource_pair : resources) { - std::string resource_name = resource_pair.first; - double resource_quantity = resource_pair.second; - if (state->dynamic_resources[resource_name] < resource_quantity) { - return false; - } - } - return true; -} - -void resource_sanity_checks(LocalSchedulerState *state, - LocalSchedulerClient *worker) { - // Check the resources in use by the worker. - for (auto const &resource_pair : worker->resources_in_use) { - const std::string resource_name = resource_pair.first; - double resource_quantity = resource_pair.second; - - RAY_CHECK(state->dynamic_resources[resource_name] <= - state->static_resources[resource_name]); - if (resource_name != std::string("CPU")) { - RAY_CHECK(state->dynamic_resources[resource_name] >= 0); - } - - RAY_CHECK(resource_quantity >= 0); - RAY_CHECK(resource_quantity <= state->static_resources[resource_name]); - } -} - -/* TODO(atumanov): just pass the required resource vector of doubles. */ -void acquire_resources( - LocalSchedulerState *state, - LocalSchedulerClient *worker, - const std::unordered_map &resources) { - // Loop over each required resource type and acquire the appropriate quantity. - for (auto const &resource_pair : resources) { - const std::string resource_name = resource_pair.first; - double resource_quantity = resource_pair.second; - - // Do some special handling for GPU resources. - if (resource_name == std::string("GPU")) { - if (resource_quantity != 0) { - // Make sure that the worker isn't using any GPUs already. - RAY_CHECK(worker->gpus_in_use.size() == 0); - RAY_CHECK(state->available_gpus.size() >= resource_quantity); - // Reserve GPUs for the worker. - for (int i = 0; i < resource_quantity; i++) { - worker->gpus_in_use.push_back(state->available_gpus.back()); - state->available_gpus.pop_back(); - } - } - } - - // Do bookkeeping for general resource types. - if (resource_name != std::string("CPU")) { - RAY_CHECK(state->dynamic_resources[resource_name] >= resource_quantity); - } - state->dynamic_resources[resource_name] -= resource_quantity; - worker->resources_in_use[resource_name] += resource_quantity; - } - - // Do some sanity checks. - resource_sanity_checks(state, worker); -} - -void release_resources( - LocalSchedulerState *state, - LocalSchedulerClient *worker, - const std::unordered_map &resources) { - for (auto const &resource_pair : resources) { - const std::string resource_name = resource_pair.first; - double resource_quantity = resource_pair.second; - - // Do some special handling for GPU resources. 
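The acquire/release pair deleted here keeps two per-resource maps (total capacity and currently free quantity) plus a list of concrete GPU IDs that are lent to workers and returned on release. A simplified sketch of that bookkeeping, ignoring the special-case CPU accounting in the original; ResourcePool, WorkerSlot, Acquire, and Release are illustrative names.

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// A node tracks total capacity and currently free quantity per resource name,
// plus a pool of concrete GPU IDs that are handed to workers.
struct ResourcePool {
  std::unordered_map<std::string, double> static_resources;   // capacity
  std::unordered_map<std::string, double> dynamic_resources;  // currently free
  std::vector<int> available_gpus;                            // free GPU IDs
};

struct WorkerSlot {
  std::unordered_map<std::string, double> in_use;
  std::vector<int> gpus_in_use;
};

// Acquire: subtract from the free pool, record usage on the worker, and hand
// out concrete GPU IDs from the back of the free list.
void Acquire(ResourcePool &pool, WorkerSlot &worker,
             const std::unordered_map<std::string, double> &request) {
  for (const auto &[name, quantity] : request) {
    if (name == "GPU" && quantity > 0) {
      assert(pool.available_gpus.size() >= static_cast<size_t>(quantity));
      for (int i = 0; i < static_cast<int>(quantity); ++i) {
        worker.gpus_in_use.push_back(pool.available_gpus.back());
        pool.available_gpus.pop_back();
      }
    }
    pool.dynamic_resources[name] -= quantity;
    worker.in_use[name] += quantity;
  }
}

// Release: return GPU IDs to the pool and add the scalar quantities back.
void Release(ResourcePool &pool, WorkerSlot &worker) {
  for (int gpu : worker.gpus_in_use) {
    pool.available_gpus.push_back(gpu);
  }
  worker.gpus_in_use.clear();
  for (const auto &[name, quantity] : worker.in_use) {
    pool.dynamic_resources[name] += quantity;
  }
  worker.in_use.clear();
}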
- if (resource_name == std::string("GPU")) { - if (resource_quantity != 0) { - RAY_CHECK(resource_quantity == worker->gpus_in_use.size()); - // Move the GPU IDs the worker was using back to the local scheduler. - for (auto const &gpu_id : worker->gpus_in_use) { - state->available_gpus.push_back(gpu_id); - } - worker->gpus_in_use.clear(); - } - } - - // Do bookkeeping for general resources types. - state->dynamic_resources[resource_name] += resource_quantity; - worker->resources_in_use[resource_name] -= resource_quantity; - } - - // Do some sanity checks. - resource_sanity_checks(state, worker); -} - -bool is_driver_alive(LocalSchedulerState *state, WorkerID driver_id) { - return state->removed_drivers.count(driver_id) == 0; -} - -void assign_task_to_worker(LocalSchedulerState *state, - TaskExecutionSpec &execution_spec, - LocalSchedulerClient *worker) { - int64_t task_spec_size = execution_spec.SpecSize(); - TaskSpec *spec = execution_spec.Spec(); - // Acquire the necessary resources for running this task. - const std::unordered_map required_resources = - TaskSpec_get_required_resources(spec); - acquire_resources(state, worker, required_resources); - // Check that actor tasks don't have non-CPU requirements. Any necessary - // non-CPU resources (in particular, GPUs) should already have been acquired - // by the actor worker. - if (!worker->actor_id.is_nil()) { - RAY_CHECK(required_resources.size() == 1); - RAY_CHECK(required_resources.count("CPU") == 1); - } - - RAY_CHECK(worker->actor_id == TaskSpec_actor_id(spec)); - /* Make sure the driver for this task is still alive. */ - WorkerID driver_id = TaskSpec_driver_id(spec); - RAY_CHECK(is_driver_alive(state, driver_id)); - - /* Construct a flatbuffer object to send to the worker. */ - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::local_scheduler::protocol::CreateGetTaskReply( - fbb, fbb.CreateString((char *) spec, task_spec_size), - fbb.CreateVector(worker->gpus_in_use)); - fbb.Finish(message); - - if (write_message(worker->sock, - static_cast(MessageType::ExecuteTask), - fbb.GetSize(), (uint8_t *) fbb.GetBufferPointer()) < 0) { - if (errno == EPIPE || errno == EBADF) { - /* Something went wrong, so kill the worker. */ - kill_worker(state, worker, false, false); - RAY_LOG(WARNING) << "Failed to give task to worker on fd " << worker->sock - << ". The client may have hung up."; - } else { - RAY_LOG(FATAL) << "Failed to give task to client on fd " << worker->sock; - } - } - - Task *task = - Task_alloc(execution_spec, TaskStatus::RUNNING, - state->db ? get_db_client_id(state->db) : DBClientID::nil()); - /* Record which task this worker is executing. This will be freed in - * process_message when the worker sends a GetTask message to the local - * scheduler. */ - worker->task_in_progress = Task_copy(task); - /* Update the global task table. */ - if (state->db != NULL) { - task_table_update(state->db, task, NULL, NULL, NULL); - } else { - Task_free(task); - } -} - -// This is used to allow task_table_update to fail. -void allow_task_table_update_failure(UniqueID id, - void *user_context, - void *user_data) {} - -void finish_task(LocalSchedulerState *state, LocalSchedulerClient *worker) { - if (worker->task_in_progress != NULL) { - TaskSpec *spec = Task_task_execution_spec(worker->task_in_progress)->Spec(); - // Return dynamic resources back for the task in progress. 
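When handing a task to a worker, the deleted code serializes a GetTaskReply flatbuffer and pushes it through write_message, treating EPIPE/EBADF as a worker that hung up (kill it and continue) and any other write failure as fatal. The sketch below shows the general pattern of a length-prefixed write with that error split; the header layout and the names WriteAll, WriteMessage, and WriteResult are illustrative assumptions, not Ray's actual wire format.

#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <unistd.h>
#include <vector>

// Outcome of pushing a message to a worker socket: success, a client that has
// gone away (treat like a dead worker), or an unexpected failure.
enum class WriteResult { kOk, kClientGone, kFatal };

// Write the full buffer, retrying on EINTR.
static bool WriteAll(int fd, const uint8_t *data, size_t size) {
  size_t written = 0;
  while (written < size) {
    ssize_t n = write(fd, data + written, size - written);
    if (n < 0) {
      if (errno == EINTR) {
        continue;
      }
      return false;
    }
    written += static_cast<size_t>(n);
  }
  return true;
}

// Send a type tag and length header followed by the payload, mapping
// EPIPE/EBADF to kClientGone the way the deleted code kills the worker and
// carries on, and everything else to kFatal.
WriteResult WriteMessage(int fd, int64_t type,
                         const std::vector<uint8_t> &payload) {
  int64_t header[2] = {type, static_cast<int64_t>(payload.size())};
  if (!WriteAll(fd, reinterpret_cast<const uint8_t *>(header),
                sizeof(header)) ||
      !WriteAll(fd, payload.data(), payload.size())) {
    return (errno == EPIPE || errno == EBADF) ? WriteResult::kClientGone
                                              : WriteResult::kFatal;
  }
  return WriteResult::kOk;
}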
- if (TaskSpec_is_actor_creation_task(spec)) { - // Resources required by the actor creation task are acquired for the - // actor's lifetime, so don't return anything here. TODO(rkn): Should the - // actor creation task require 1 CPU in addition to any resources acquired - // for the lifetime of the actor? If not, then the local scheduler may - // schedule an arbitrary number of actor creation tasks concurrently (if - // they don't acquire any resources for their entire lifetime). In - // practice this will usually be rate-limited by the rate at which we can - // create new workers. - - ActorID actor_creation_id = TaskSpec_actor_creation_id(spec); - WorkerID driver_id = TaskSpec_driver_id(spec); - - // The driver must be alive because if the driver had been removed, then - // this worker would have been killed (because it was executing a task for - // the driver). - RAY_CHECK(is_driver_alive(state, driver_id)); - - // Update the worker struct with this actor ID. - RAY_CHECK(worker->actor_id.is_nil()); - worker->actor_id = actor_creation_id; - // Extract the initial execution dependency from the actor creation task. - RAY_CHECK(TaskSpec_num_returns(spec) == 1); - ObjectID initial_execution_dependency = TaskSpec_return(spec, 0); - // Let the scheduling algorithm process the presence of this new worker. - handle_convert_worker_to_actor(state, state->algorithm_state, - actor_creation_id, - initial_execution_dependency, worker); - // Publish the actor creation notification. The corresponding callback - // handle_actor_creation_callback will update state->actor_mapping. - publish_actor_creation_notification( - state->db, actor_creation_id, driver_id, get_db_client_id(state->db)); - } else if (worker->actor_id.is_nil()) { - // Return dynamic resources back for the task in progress. - RAY_CHECK(worker->resources_in_use["CPU"] == - TaskSpec_get_required_resource(spec, "CPU")); - // Return GPU resources. - RAY_CHECK(worker->gpus_in_use.size() == - TaskSpec_get_required_resource(spec, "GPU")); - release_resources(state, worker, worker->resources_in_use); - } else { - // Actor tasks should only specify CPU requirements. - RAY_CHECK(0 == TaskSpec_get_required_resource(spec, "GPU")); - std::unordered_map cpu_resources; - cpu_resources["CPU"] = TaskSpec_get_required_resource(spec, "CPU"); - release_resources(state, worker, cpu_resources); - } - /* If we're connected to Redis, update tables. */ - if (state->db != NULL) { - /* Update control state tables. */ - TaskStatus task_state = TaskStatus::DONE; - Task_set_state(worker->task_in_progress, task_state); - auto retryInfo = RetryInfo{ - .num_retries = 0, // This value is unused. - .timeout = 0, // This value is unused. - .fail_callback = allow_task_table_update_failure, - }; - - // We allow this call to fail in case the driver has been removed and the - // task table entries have already been cleaned up by the monitor. - task_table_update(state->db, worker->task_in_progress, &retryInfo, NULL, - NULL); - } else { - Task_free(worker->task_in_progress); - } - /* The call to task_table_update takes ownership of the - * task_in_progress, so we set the pointer to NULL so it is not used. */ - worker->task_in_progress = NULL; - } -} - -void process_plasma_notification(event_loop *loop, - int client_sock, - void *context, - int events) { - LocalSchedulerState *state = (LocalSchedulerState *) context; - /* Read the notification from Plasma. 
*/ - uint8_t *notification = read_message_async(loop, client_sock); - if (!notification) { - /* The store has closed the socket. */ - LocalSchedulerState_free(state); - RAY_LOG(FATAL) << "Lost connection to the plasma store, local scheduler is " - << "exiting!"; - } - auto object_info = flatbuffers::GetRoot(notification); - ObjectID object_id = from_flatbuf(*object_info->object_id()); - if (object_info->is_deletion()) { - handle_object_removed(state, object_id); - } else { - handle_object_available(state, state->algorithm_state, object_id); - } - free(notification); -} - -void reconstruct_task_update_callback(Task *task, - void *user_context, - bool updated) { - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - if (!updated) { - /* The test-and-set failed. The task is either: (1) not finished yet, (2) - * lost, but not yet updated, or (3) already being reconstructed. */ - DBClientID current_local_scheduler_id = Task_local_scheduler(task); - if (!current_local_scheduler_id.is_nil()) { - DBClient current_local_scheduler = - db_client_table_cache_get(state->db, current_local_scheduler_id); - if (!current_local_scheduler.is_alive) { - /* (2) The current local scheduler for the task is dead. The task is - * lost, but the task table hasn't received the update yet. Retry the - * test-and-set. */ - task_table_test_and_update(state->db, Task_task_id(task), - current_local_scheduler_id, Task_state(task), - TaskStatus::RECONSTRUCTING, NULL, - reconstruct_task_update_callback, state); - } - } - /* The test-and-set failed, so it is not safe to resubmit the task for - * execution. Suppress the request. */ - return; - } - - /* Otherwise, the test-and-set succeeded, so resubmit the task for execution - * to ensure that reconstruction will happen. */ - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - TaskSpec *spec = execution_spec->Spec(); - if (TaskSpec_actor_id(spec).is_nil()) { - handle_task_submitted(state, state->algorithm_state, *execution_spec); - } else { - handle_actor_task_submitted(state, state->algorithm_state, *execution_spec); - } - - /* Recursively reconstruct the task's inputs, if necessary. */ - int64_t num_dependencies = execution_spec->NumDependencies(); - for (int64_t i = 0; i < num_dependencies; ++i) { - int count = execution_spec->DependencyIdCount(i); - for (int64_t j = 0; j < count; ++j) { - ObjectID dependency_id = execution_spec->DependencyId(i, j); - reconstruct_object(state, dependency_id); - } - } -} - -void reconstruct_put_task_update_callback(Task *task, - void *user_context, - bool updated) { - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - if (!updated) { - /* The test-and-set failed. The task is either: (1) not finished yet, (2) - * lost, but not yet updated, or (3) already being reconstructed. */ - DBClientID current_local_scheduler_id = Task_local_scheduler(task); - if (!current_local_scheduler_id.is_nil()) { - DBClient current_local_scheduler = - db_client_table_cache_get(state->db, current_local_scheduler_id); - if (!current_local_scheduler.is_alive) { - /* (2) The current local scheduler for the task is dead. The task is - * lost, but the task table hasn't received the update yet. Retry the - * test-and-set. */ - task_table_test_and_update(state->db, Task_task_id(task), - current_local_scheduler_id, Task_state(task), - TaskStatus::RECONSTRUCTING, NULL, - reconstruct_put_task_update_callback, state); - } else if (Task_state(task) == TaskStatus::RUNNING) { - /* (1) The task is still executing on a live node. 
The object created - * by `ray.put` was not able to be reconstructed, and the workload will - * likely hang. Push an error to the appropriate driver. */ - TaskSpec *spec = Task_task_execution_spec(task)->Spec(); - - std::ostringstream error_message; - error_message << "The task with ID " << TaskSpec_task_id(spec) - << " is still executing and so the object created by " - << "ray.put could not be reconstructed."; - push_error(state->db, TaskSpec_driver_id(spec), - ErrorIndex::PUT_RECONSTRUCTION, error_message.str()); - } - } else { - /* (1) The task is still executing and it is the driver task. We cannot - * restart the driver task, so the workload will hang. Push an error to - * the appropriate driver. */ - TaskSpec *spec = Task_task_execution_spec(task)->Spec(); - - std::ostringstream error_message; - error_message << "The task with ID " << TaskSpec_task_id(spec) - << " is a driver task and so the object created by ray.put " - << "could not be reconstructed."; - push_error(state->db, TaskSpec_driver_id(spec), - ErrorIndex::PUT_RECONSTRUCTION, error_message.str()); - } - } else { - /* The update to TaskStatus::RECONSTRUCTING succeeded, so continue with - * reconstruction as usual. */ - reconstruct_task_update_callback(task, user_context, updated); - } -} - -void reconstruct_evicted_result_lookup_callback(ObjectID reconstruct_object_id, - TaskID task_id, - bool is_put, - void *user_context) { - RAY_CHECK(!task_id.is_nil()) - << "No task information found for object during reconstruction"; - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - - task_table_test_and_update_callback done_callback; - if (is_put) { - /* If the evicted object was created through ray.put and the originating - * task - * is still executing, it's very likely that the workload will hang and the - * worker needs to be restarted. Else, the reconstruction behavior is the - * same as for other evicted objects */ - done_callback = reconstruct_put_task_update_callback; - } else { - done_callback = reconstruct_task_update_callback; - } - /* If there are no other instances of the task running, it's safe for us to - * claim responsibility for reconstruction. */ - task_table_test_and_update(state->db, task_id, DBClientID::nil(), - (TaskStatus::DONE | TaskStatus::LOST), - TaskStatus::RECONSTRUCTING, NULL, done_callback, - state); -} - -void reconstruct_failed_result_lookup_callback(ObjectID reconstruct_object_id, - TaskID task_id, - bool is_put, - void *user_context) { - if (task_id.is_nil()) { - /* NOTE(swang): For some reason, the result table update sometimes happens - * after this lookup returns, possibly due to concurrent clients. In most - * cases, this is okay because the initial execution is probably still - * pending, so for now, we log a warning and suppress reconstruction. */ - RAY_LOG(WARNING) << "No task information found for object during " - << "reconstruction (no object entry yet)"; - return; - } - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - /* If the task failed to finish, it's safe for us to claim responsibility for - * reconstruction. 
*/ - task_table_test_and_update(state->db, task_id, DBClientID::nil(), - TaskStatus::LOST, TaskStatus::RECONSTRUCTING, NULL, - reconstruct_task_update_callback, state); -} - -void reconstruct_object_lookup_callback( - ObjectID reconstruct_object_id, - bool never_created, - const std::vector &manager_ids, - void *user_context) { - RAY_LOG(DEBUG) << "Manager count was " << manager_ids.size(); - /* Only continue reconstruction if we find that the object doesn't exist on - * any nodes. NOTE: This codepath is not responsible for checking if the - * object table entry is up-to-date. */ - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - /* Look up the task that created the object in the result table. */ - if (never_created) { - /* If the object has not been created yet, we reconstruct the object if and - * only if the task that created the object failed to complete. */ - result_table_lookup(state->db, reconstruct_object_id, NULL, - reconstruct_failed_result_lookup_callback, - (void *) state); - } else { - /* If the object has been created, filter out the dead plasma managers that - * have it. */ - size_t num_live_managers = 0; - for (auto manager_id : manager_ids) { - DBClient manager = db_client_table_cache_get(state->db, manager_id); - if (manager.is_alive) { - num_live_managers++; - } - } - /* If the object was created, but all plasma managers that had the object - * either evicted it or failed, we reconstruct the object if and only if - * there are no other instances of the task running. */ - if (num_live_managers == 0) { - result_table_lookup(state->db, reconstruct_object_id, NULL, - reconstruct_evicted_result_lookup_callback, - (void *) state); - } - } -} - -void reconstruct_object(LocalSchedulerState *state, - ObjectID reconstruct_object_id) { - RAY_LOG(DEBUG) << "Starting reconstruction"; - /* If the object is locally available, no need to reconstruct. */ - if (object_locally_available(state->algorithm_state, reconstruct_object_id)) { - return; - } - /* Determine if reconstruction is necessary by checking if the object exists - * on a node. */ - RAY_CHECK(state->db != NULL); - object_table_lookup(state->db, reconstruct_object_id, NULL, - reconstruct_object_lookup_callback, (void *) state); -} - -void handle_client_register( - LocalSchedulerState *state, - LocalSchedulerClient *worker, - const ray::local_scheduler::protocol::RegisterClientRequest *message) { - /* Make sure this worker hasn't already registered. */ - RAY_CHECK(!worker->registered); - worker->registered = true; - worker->is_worker = message->is_worker(); - RAY_CHECK(worker->client_id.is_nil()); - worker->client_id = from_flatbuf(*message->client_id()); - - /* Register the worker or driver. */ - if (worker->is_worker) { - /* Update the actor mapping with the actor ID of the worker (if an actor is - * running on the worker). */ - worker->pid = message->worker_pid(); - /* Register worker process id with the scheduler. */ - /* Determine if this worker is one of our child processes. */ - RAY_LOG(DEBUG) << "PID is " << worker->pid; - auto it = std::find(state->child_pids.begin(), state->child_pids.end(), - worker->pid); - if (it != state->child_pids.end()) { - /* If this worker is one of our child processes, mark it as a child so - * that we know that we can wait for the process to exit during - * cleanup. */ - worker->is_child = true; - state->child_pids.erase(it); - RAY_LOG(DEBUG) << "Found matching child pid " << worker->pid; - } - } else { - /* Register the driver. Currently we don't do anything here. 
*/ - } -} - -void handle_driver_removed_callback(WorkerID driver_id, void *user_context) { - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - - /* Kill any actors that were created by the removed driver, and kill any - * workers that are currently running tasks from the dead driver. */ - auto it = state->workers.begin(); - while (it != state->workers.end()) { - /* Increment the iterator by one before calling kill_worker, because - * kill_worker will invalidate the iterator. Note that this requires - * knowledge of the particular container that we are iterating over (in this - * case it is a list). */ - auto next_it = it; - next_it++; - - ActorID actor_id = (*it)->actor_id; - Task *task = (*it)->task_in_progress; - - if (!actor_id.is_nil()) { - /* This is an actor. */ - RAY_CHECK(state->actor_mapping.count(actor_id) == 1); - if (state->actor_mapping[actor_id].driver_id == driver_id) { - /* This actor was created by the removed driver, so kill the actor. */ - RAY_LOG(DEBUG) << "Killing an actor for a removed driver."; - kill_worker(state, *it, false, true); - } - } else if (task != NULL) { - TaskSpec *spec = Task_task_execution_spec(task)->Spec(); - if (TaskSpec_driver_id(spec) == driver_id) { - RAY_LOG(DEBUG) << "Killing a worker executing a task for a removed " - << "driver."; - kill_worker(state, *it, false, true); - } - } - - it = next_it; - } - - /* Add the driver to a list of dead drivers. */ - state->removed_drivers.insert(driver_id); - - /* Notify the scheduling algorithm that the driver has been removed. It should - * remove tasks for that driver from its data structures. */ - handle_driver_removed(state, state->algorithm_state, driver_id); -} - -void handle_client_disconnect(LocalSchedulerState *state, - LocalSchedulerClient *worker) { - if (!worker->registered || worker->is_worker) { - } else { - /* In this case, a driver is disconecting. */ - driver_table_send_driver_death(state->db, worker->client_id, NULL); - } - /* Suppress the warning message if the worker already disconnected. */ - kill_worker(state, worker, false, worker->disconnected); -} - -void handle_get_actor_frontier(LocalSchedulerState *state, - LocalSchedulerClient *worker, - ActorID actor_id) { - auto task_counters = - get_actor_task_counters(state->algorithm_state, actor_id); - auto frontier = get_actor_frontier(state->algorithm_state, actor_id); - - /* Build the ActorFrontier flatbuffer. */ - std::vector handle_vector; - std::vector task_counter_vector; - std::vector frontier_vector; - for (auto handle : task_counters) { - handle_vector.push_back(handle.first); - task_counter_vector.push_back(handle.second); - frontier_vector.push_back(frontier[handle.first]); - } - flatbuffers::FlatBufferBuilder fbb; - auto reply = ray::local_scheduler::protocol::CreateActorFrontier( - fbb, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, handle_vector), - fbb.CreateVector(task_counter_vector), to_flatbuf(fbb, frontier_vector)); - fbb.Finish(reply); - /* Respond with the built ActorFrontier. */ - if (write_message(worker->sock, - static_cast(MessageType::GetActorFrontierReply), - fbb.GetSize(), (uint8_t *) fbb.GetBufferPointer()) < 0) { - if (errno == EPIPE || errno == EBADF) { - /* Something went wrong, so kill the worker. */ - kill_worker(state, worker, false, false); - RAY_LOG(WARNING) << "Failed to return actor frontier to worker on fd " - << worker->sock << ". 
The client may have hung up."; - } else { - RAY_LOG(FATAL) << "Failed to give task to client on fd " << worker->sock; - } - } -} - -void handle_set_actor_frontier( - LocalSchedulerState *state, - LocalSchedulerClient *worker, - ray::local_scheduler::protocol::ActorFrontier const &frontier) { - /* Parse the ActorFrontier flatbuffer. */ - ActorID actor_id = from_flatbuf(*frontier.actor_id()); - std::unordered_map task_counters; - std::unordered_map frontier_dependencies; - for (size_t i = 0; i < frontier.handle_ids()->size(); ++i) { - ActorID handle_id = from_flatbuf(*frontier.handle_ids()->Get(i)); - task_counters[handle_id] = frontier.task_counters()->Get(i); - frontier_dependencies[handle_id] = - from_flatbuf(*frontier.frontier_dependencies()->Get(i)); - } - /* Set the actor's frontier. */ - set_actor_task_counters(state->algorithm_state, actor_id, task_counters); - set_actor_frontier(state, state->algorithm_state, actor_id, - frontier_dependencies); -} - -void process_message(event_loop *loop, - int client_sock, - void *context, - int events) { - int64_t start_time = current_time_ms(); - - LocalSchedulerClient *worker = (LocalSchedulerClient *) context; - LocalSchedulerState *state = worker->local_scheduler_state; - - int64_t type; - read_vector(client_sock, &type, state->input_buffer); - uint8_t *input = state->input_buffer.data(); - - RAY_LOG(DEBUG) << "New event of type " << type; - - switch (type) { - case static_cast(MessageType::SubmitTask): { - auto message = - flatbuffers::GetRoot( - input); - TaskExecutionSpec execution_spec = - TaskExecutionSpec(from_flatbuf(*message->execution_dependencies()), - (TaskSpec *) message->task_spec()->data(), - message->task_spec()->size()); - /* Set the tasks's local scheduler entrypoint time. */ - execution_spec.SetLastTimeStamp(current_time_ms()); - TaskSpec *spec = execution_spec.Spec(); - /* Update the result table, which holds mappings of object ID -> ID of the - * task that created it. */ - if (state->db != NULL) { - TaskID task_id = TaskSpec_task_id(spec); - for (int64_t i = 0; i < TaskSpec_num_returns(spec); ++i) { - ObjectID return_id = TaskSpec_return(spec, i); - result_table_add(state->db, return_id, task_id, false, NULL, NULL, - NULL); - } - } - - /* Handle the task submission. */ - if (TaskSpec_actor_id(spec).is_nil()) { - handle_task_submitted(state, state->algorithm_state, execution_spec); - } else { - handle_actor_task_submitted(state, state->algorithm_state, - execution_spec); - } - } break; - case static_cast(MessageType::TaskDone): { - } break; - case static_cast(MessageType::DisconnectClient): { - finish_task(state, worker); - RAY_CHECK(!worker->disconnected); - worker->disconnected = true; - /* If the disconnected worker was not an actor, start a new worker to make - * sure there are enough workers in the pool. */ - if (worker->actor_id.is_nil()) { - start_worker(state); - } - } break; - case static_cast(MessageType::EventLogMessage): { - /* Parse the message. 
*/ - auto message = - flatbuffers::GetRoot( - input); - if (state->db != NULL) { - RayLogger_log_event(state->db, (uint8_t *) message->key()->data(), - message->key()->size(), - (uint8_t *) message->value()->data(), - message->value()->size(), message->timestamp()); - } - } break; - case static_cast(MessageType::RegisterClientRequest): { - auto message = flatbuffers::GetRoot< - ray::local_scheduler::protocol::RegisterClientRequest>(input); - handle_client_register(state, worker, message); - } break; - case static_cast(MessageType::GetTask): { - /* If this worker reports a completed task, account for resources. */ - finish_task(state, worker); - /* Let the scheduling algorithm process the fact that there is an available - * worker. */ - if (worker->actor_id.is_nil()) { - handle_worker_available(state, state->algorithm_state, worker); - } else { - handle_actor_worker_available(state, state->algorithm_state, worker); - } - } break; - case static_cast(MessageType::ReconstructObjects): { - auto message = flatbuffers::GetRoot< - ray::local_scheduler::protocol::ReconstructObjects>(input); - RAY_CHECK(!message->fetch_only()); - if (worker->task_in_progress != NULL && !worker->is_blocked) { - /* If the worker was executing a task (i.e. non-driver) and it wasn't - * already blocked on an object that's not locally available, update its - * state to blocked. */ - worker->is_blocked = true; - // Return the CPU resources that the blocked worker was using, but not - // other resources. If the worker is an actor, this will not return the - // CPU resources that the worker has acquired for its lifetime. It will - // only return the ones associated with the current method. - TaskSpec *spec = - Task_task_execution_spec(worker->task_in_progress)->Spec(); - std::unordered_map cpu_resources; - cpu_resources["CPU"] = TaskSpec_get_required_resource(spec, "CPU"); - release_resources(state, worker, cpu_resources); - /* Let the scheduling algorithm process the fact that the worker is - * blocked. */ - if (worker->actor_id.is_nil()) { - handle_worker_blocked(state, state->algorithm_state, worker); - } else { - handle_actor_worker_blocked(state, state->algorithm_state, worker); - } - print_worker_info("Reconstructing", state->algorithm_state); - } - RAY_CHECK(message->object_ids()->size() == 1); - ObjectID object_id = from_flatbuf(*message->object_ids()->Get(0)); - reconstruct_object(state, object_id); - } break; - case static_cast(CommonMessageType::DISCONNECT_CLIENT): { - RAY_LOG(DEBUG) << "Disconnecting client on fd " << client_sock; - handle_client_disconnect(state, worker); - } break; - case static_cast(MessageType::NotifyUnblocked): { - /* TODO(rkn): A driver may call this as well, right? */ - if (worker->task_in_progress != NULL) { - /* If the worker was executing a task (i.e. non-driver), update its - * state to not blocked. */ - RAY_CHECK(worker->is_blocked); - worker->is_blocked = false; - /* Lease back the CPU resources that the blocked worker needs (note that - * it never released its GPU resources). TODO(swang): Leasing back the - * resources to blocked workers can cause us to transiently exceed the - * maximum number of resources. This could be fixed by having blocked - * workers explicitly yield and wait to be given back resources before - * continuing execution. 
*/ - TaskSpec *spec = - Task_task_execution_spec(worker->task_in_progress)->Spec(); - std::unordered_map cpu_resources; - cpu_resources["CPU"] = TaskSpec_get_required_resource(spec, "CPU"); - acquire_resources(state, worker, cpu_resources); - /* Let the scheduling algorithm process the fact that the worker is - * unblocked. */ - if (worker->actor_id.is_nil()) { - handle_worker_unblocked(state, state->algorithm_state, worker); - } else { - handle_actor_worker_unblocked(state, state->algorithm_state, worker); - } - } - print_worker_info("Worker unblocked", state->algorithm_state); - } break; - case static_cast(MessageType::PutObject): { - auto message = - flatbuffers::GetRoot(input); - result_table_add(state->db, from_flatbuf(*message->object_id()), - from_flatbuf(*message->task_id()), true, NULL, NULL, NULL); - } break; - case static_cast(MessageType::GetActorFrontierRequest): { - auto message = flatbuffers::GetRoot< - ray::local_scheduler::protocol::GetActorFrontierRequest>(input); - ActorID actor_id = from_flatbuf(*message->actor_id()); - handle_get_actor_frontier(state, worker, actor_id); - } break; - case static_cast(MessageType::SetActorFrontier): { - auto message = - flatbuffers::GetRoot( - input); - handle_set_actor_frontier(state, worker, *message); - } break; - default: - /* This code should be unreachable. */ - RAY_CHECK(0); - } - - /* Print a warning if this method took too long. */ - int64_t end_time = current_time_ms(); - if (end_time - start_time > - RayConfig::instance().max_time_for_handler_milliseconds()) { - RAY_LOG(WARNING) << "process_message of type " << type << " took " - << end_time - start_time << " milliseconds."; - } -} - -void new_client_connection(event_loop *loop, - int listener_sock, - void *context, - int events) { - LocalSchedulerState *state = (LocalSchedulerState *) context; - int new_socket = accept_client(listener_sock); - /* Create a struct for this worker. This will be freed when we free the local - * scheduler state. */ - LocalSchedulerClient *worker = new LocalSchedulerClient(); - worker->sock = new_socket; - worker->registered = false; - worker->disconnected = false; - /* We don't know whether this is a worker or not, so just initialize is_worker - * to false. */ - worker->is_worker = true; - worker->client_id = WorkerID::nil(); - worker->task_in_progress = NULL; - worker->is_blocked = false; - worker->pid = 0; - worker->is_child = false; - worker->actor_id = ActorID::nil(); - worker->local_scheduler_state = state; - state->workers.push_back(worker); - event_loop_add_file(loop, new_socket, EVENT_LOOP_READ, process_message, - worker); - RAY_LOG(DEBUG) << "new connection with fd " << new_socket; -} - -/* We need this code so we can clean up when we get a SIGTERM signal. */ - -LocalSchedulerState *g_state = NULL; - -void signal_handler(int signal) { - RAY_LOG(DEBUG) << "Signal was " << signal; - if (signal == SIGTERM) { - /* NOTE(swang): This call removes the SIGTERM handler to ensure that we - * free the local scheduler state at most once. If another SIGTERM is - * caught during this call, there is the possibility of orphan worker - * processes. */ - if (g_state) { - LocalSchedulerState_free(g_state); - } - exit(0); - } -} - -/* End of the cleanup code. 
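The ReconstructObjects and NotifyUnblocked cases above are two halves of the same bookkeeping: a worker that blocks on a missing object temporarily returns only the CPUs required by its current task (its GPU lease is never released), and it re-acquires those CPUs when it is unblocked, which, as the TODO notes, can transiently exceed the configured capacity. A small stand-alone sketch of that release/acquire pattern is below; ResourcePool and the hard-coded quantities are illustrative, not the real LocalSchedulerState accounting.

    #include <cassert>
    #include <string>
    #include <unordered_map>

    // Simplified stand-in for the scheduler's dynamic resource accounting.
    struct ResourcePool {
      std::unordered_map<std::string, double> available;

      void Release(const std::unordered_map<std::string, double> &resources) {
        for (const auto &kv : resources) available[kv.first] += kv.second;
      }
      void Acquire(const std::unordered_map<std::string, double> &resources) {
        for (const auto &kv : resources) available[kv.first] -= kv.second;
      }
    };

    int main() {
      ResourcePool pool;
      // 4 CPUs and 1 GPU total; a running task has already acquired 2 CPUs and the GPU.
      pool.available = {{"CPU", 2.0}, {"GPU", 0.0}};
      std::unordered_map<std::string, double> cpu_only = {{"CPU", 2.0}};

      // The worker blocks on a remote object: hand back only the CPUs, keep the GPU leased.
      pool.Release(cpu_only);
      assert(pool.available["CPU"] == 4.0 && pool.available["GPU"] == 0.0);

      // The worker unblocks: take the CPUs back before resuming. This can transiently
      // over-subscribe CPUs if they were re-leased to another task in the meantime.
      pool.Acquire(cpu_only);
      assert(pool.available["CPU"] == 2.0);
      return 0;
    }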
*/ - -void handle_task_scheduled_callback(Task *original_task, - void *subscribe_context) { - LocalSchedulerState *state = (LocalSchedulerState *) subscribe_context; - TaskExecutionSpec *execution_spec = Task_task_execution_spec(original_task); - TaskSpec *spec = execution_spec->Spec(); - - /* Set the tasks's local scheduler entrypoint time. */ - execution_spec->SetLastTimeStamp(current_time_ms()); - - /* If the driver for this task has been removed, then don't bother telling the - * scheduling algorithm. */ - WorkerID driver_id = TaskSpec_driver_id(spec); - if (!is_driver_alive(state, driver_id)) { - RAY_LOG(DEBUG) << "Ignoring scheduled task for removed driver."; - return; - } - - if (TaskSpec_actor_id(spec).is_nil()) { - /* This task does not involve an actor. Handle it normally. */ - handle_task_scheduled(state, state->algorithm_state, *execution_spec); - } else { - /* This task involves an actor. Call the scheduling algorithm's actor - * handler. */ - handle_actor_task_scheduled(state, state->algorithm_state, *execution_spec); - } -} - -/** - * Process a notification about the creation of a new actor. Use this to update - * the mapping from actor ID to the local scheduler ID of the local scheduler - * that is responsible for the actor. If this local scheduler is responsible for - * the actor, then launch a new worker process to create that actor. - * - * @param actor_id The ID of the actor being created. - * @param local_scheduler_id The ID of the local scheduler that is responsible - * for creating the actor. - * @param context The context for this callback. - * @return Void. - */ -void handle_actor_creation_callback(const ActorID &actor_id, - const WorkerID &driver_id, - const DBClientID &local_scheduler_id, - void *context) { - LocalSchedulerState *state = (LocalSchedulerState *) context; - - /* If the driver has been removed, don't bother doing anything. */ - if (state->removed_drivers.count(driver_id) == 1) { - return; - } - - // TODO(rkn): If we do not have perfect task suppression and it is possible - // for a task to be executed simultaneously on two nodes, then we will need to - // detect and handle that case. - - if (state->actor_mapping.count(actor_id) != 0) { - // This actor already exists. - auto it = state->actor_mapping.find(actor_id); - if (it->second.local_scheduler_id == get_db_client_id(state->db)) { - // TODO(rkn): The actor was previously assigned to this local scheduler. - // We should kill the actor here if it is still around. Also, if it hasn't - // registered yet, we should keep track of its PID so we can kill it - // anyway. - // TODO(swang): Evict actor dummy objects as part of actor cleanup. - } - } - - /* Create a new entry and add it to the actor mapping table. TODO(rkn): - * Currently this is never removed (except when the local scheduler state is - * deleted). */ - ActorMapEntry entry; - entry.local_scheduler_id = local_scheduler_id; - entry.driver_id = driver_id; - state->actor_mapping[actor_id] = entry; - - /* Let the scheduling algorithm process the fact that a new actor has been - * created. */ - handle_actor_creation_notification(state, state->algorithm_state, actor_id); -} - -int heartbeat_handler(event_loop *loop, timer_id id, void *context) { - LocalSchedulerState *state = (LocalSchedulerState *) context; - SchedulingAlgorithmState *algorithm_state = state->algorithm_state; - - // Spillback policy invocation is synchronized with the heartbeats. - spillback_tasks_handler(state); - - /* Check that the last heartbeat was not sent too long ago. 
*/ - int64_t current_time = current_time_ms(); - RAY_CHECK(current_time >= state->previous_heartbeat_time); - if (current_time - state->previous_heartbeat_time > - RayConfig::instance().num_heartbeats_timeout() * - RayConfig::instance().heartbeat_timeout_milliseconds()) { - RAY_LOG(FATAL) << "The last heartbeat was sent " - << current_time - state->previous_heartbeat_time - << " milliseconds ago."; - } - state->previous_heartbeat_time = current_time; - - LocalSchedulerInfo info; - /* Ask the scheduling algorithm to fill out the scheduler info struct. */ - provide_scheduler_info(state, algorithm_state, &info); - /* Publish the heartbeat to all subscribers of the local scheduler table. */ - local_scheduler_table_send_info(state->db, &info, NULL); - /* Reset the timer. */ - return RayConfig::instance().heartbeat_timeout_milliseconds(); -} - -void start_server( - const char *node_ip_address, - const char *socket_name, - const char *redis_primary_addr, - int redis_primary_port, - const char *plasma_store_socket_name, - const char *plasma_manager_socket_name, - const char *plasma_manager_address, - bool global_scheduler_exists, - const std::unordered_map &static_resource_conf, - const char *start_worker_command, - int num_workers) { - /* Ignore SIGPIPE signals. If we don't do this, then when we attempt to write - * to a client that has already died, the local scheduler could die. */ - signal(SIGPIPE, SIG_IGN); - /* Ignore SIGCHLD signals. If we don't do this, then worker processes will - * become zombies instead of dying gracefully. */ - signal(SIGCHLD, SIG_IGN); - int fd = bind_ipc_sock(socket_name, true); - event_loop *loop = event_loop_create(); - g_state = LocalSchedulerState_init( - node_ip_address, loop, redis_primary_addr, redis_primary_port, - socket_name, plasma_store_socket_name, plasma_manager_socket_name, - plasma_manager_address, global_scheduler_exists, static_resource_conf, - start_worker_command, num_workers); - /* Register a callback for registering new clients. */ - event_loop_add_file(loop, fd, EVENT_LOOP_READ, new_client_connection, - g_state); - /* Subscribe to receive notifications about tasks that are assigned to this - * local scheduler by the global scheduler or by other local schedulers. - * TODO(rkn): we also need to get any tasks that were assigned to this local - * scheduler before the call to subscribe. */ - if (g_state->db != NULL) { - task_table_subscribe(g_state->db, get_db_client_id(g_state->db), - TaskStatus::SCHEDULED, handle_task_scheduled_callback, - g_state, NULL, NULL, NULL); - } - /* Subscribe to notifications about newly created actors. */ - if (g_state->db != NULL) { - actor_notification_table_subscribe( - g_state->db, handle_actor_creation_callback, g_state, NULL); - } - /* Subscribe to notifications about removed drivers. */ - if (g_state->db != NULL) { - driver_table_subscribe(g_state->db, handle_driver_removed_callback, g_state, - NULL); - } - /* Create a timer for publishing information about the load on the local - * scheduler to the local scheduler table. This message also serves as a - * heartbeat. */ - if (g_state->db != NULL) { - event_loop_add_timer(loop, - RayConfig::instance().heartbeat_timeout_milliseconds(), - heartbeat_handler, g_state); - } - /* Listen for new and deleted db clients. */ - if (g_state->db != NULL) { - db_client_table_cache_init(g_state->db); - } - /* Create a timer for fetching queued tasks' missing object dependencies. 
*/ - event_loop_add_timer( - loop, RayConfig::instance().local_scheduler_fetch_timeout_milliseconds(), - fetch_object_timeout_handler, g_state); - /* Create a timer for initiating the reconstruction of tasks' missing object - * dependencies. */ - event_loop_add_timer( - loop, RayConfig::instance() - .local_scheduler_reconstruction_timeout_milliseconds(), - reconstruct_object_timeout_handler, g_state); - // Create a timer for rerunning actor creation tasks for actor tasks that are - // cached locally. - event_loop_add_timer( - loop, RayConfig::instance() - .local_scheduler_reconstruction_timeout_milliseconds(), - rerun_actor_creation_tasks_timeout_handler, g_state); - /* Run event loop. */ - event_loop_run(loop); -} - -/* Only declare the main function if we are not in testing mode, since the test - * suite has its own declaration of main. */ -#ifndef LOCAL_SCHEDULER_TEST -int main(int argc, char *argv[]) { - InitShutdownRAII ray_log_shutdown_raii( - ray::RayLog::StartRayLog, ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, - /*log_dir=*/""); - ray::RayLog::InstallFailureSignalHandler(); - signal(SIGTERM, signal_handler); - /* Path of the listening socket of the local scheduler. */ - char *scheduler_socket_name = NULL; - /* IP address and port of the primary redis instance. */ - char *redis_primary_addr_port = NULL; - /* Socket name for the local Plasma store. */ - char *plasma_store_socket_name = NULL; - /* Socket name for the local Plasma manager. */ - char *plasma_manager_socket_name = NULL; - /* Address for the plasma manager associated with this local scheduler - * instance. */ - char *plasma_manager_address = NULL; - /* The IP address of the node that this local scheduler is running on. */ - char *node_ip_address = NULL; - /* Comma-separated list of configured resource capabilities for this node. */ - char *static_resource_list = NULL; - std::unordered_map static_resource_conf; - /* The command to run when starting new workers. */ - char *start_worker_command = NULL; - /* The number of workers to start. */ - char *num_workers_str = NULL; - int c; - bool global_scheduler_exists = true; - while ((c = getopt(argc, argv, "s:r:p:m:ga:h:c:w:n:")) != -1) { - switch (c) { - case 's': - scheduler_socket_name = optarg; - break; - case 'r': - redis_primary_addr_port = optarg; - break; - case 'p': - plasma_store_socket_name = optarg; - break; - case 'm': - plasma_manager_socket_name = optarg; - break; - case 'g': - global_scheduler_exists = false; - break; - case 'a': - plasma_manager_address = optarg; - break; - case 'h': - node_ip_address = optarg; - break; - case 'c': - static_resource_list = optarg; - break; - case 'w': - start_worker_command = optarg; - break; - case 'n': - num_workers_str = optarg; - break; - default: - RAY_LOG(FATAL) << "unknown option " << c; - } - } - if (!static_resource_list) { - RAY_LOG(FATAL) << "please specify a static resource list with the -c " - << "switch"; - } - // Parse the resource list. - std::istringstream resource_string(static_resource_list); - std::string resource_name; - std::string resource_quantity; - - while (std::getline(resource_string, resource_name, ',')) { - RAY_CHECK(std::getline(resource_string, resource_quantity, ',')); - // TODO(rkn): The line below could throw an exception. What should we do - // about this? 
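The TODO just above points out that the std::stod call on the next line can throw (std::invalid_argument or std::out_of_range) when the -c resource list is malformed. One possible defensive variant of the same alternating name,quantity parsing is sketched below as a standalone helper, not a drop-in replacement; ParseResourceList and its error handling are illustrative only.

    #include <iostream>
    #include <sstream>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>

    // Parse comma-separated name,quantity pairs (e.g. "CPU,8,GPU,2").
    // Returns false on an odd number of fields or a malformed quantity.
    bool ParseResourceList(const std::string &list,
                           std::unordered_map<std::string, double> *out) {
      std::istringstream stream(list);
      std::string name, quantity;
      while (std::getline(stream, name, ',')) {
        if (!std::getline(stream, quantity, ',')) {
          return false;  // Odd number of fields.
        }
        try {
          (*out)[name] = std::stod(quantity);
        } catch (const std::logic_error &) {
          return false;  // Covers both invalid_argument and out_of_range.
        }
      }
      return true;
    }

    int main() {
      std::unordered_map<std::string, double> conf;
      std::cout << ParseResourceList("CPU,8,GPU,2", &conf) << " "      // 1
                << ParseResourceList("CPU,lots", &conf) << std::endl;  // 0
      return 0;
    }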
- static_resource_conf[resource_name] = std::stod(resource_quantity); - } - - if (!scheduler_socket_name) { - RAY_LOG(FATAL) << "please specify socket for incoming connections with " - << "-s switch"; - } - if (!plasma_store_socket_name) { - RAY_LOG(FATAL) << "please specify socket for connecting to Plasma store " - << "with -p switch"; - } - if (!node_ip_address) { - RAY_LOG(FATAL) << "please specify the node IP address with -h switch"; - } - int num_workers = 0; - if (num_workers_str) { - num_workers = strtol(num_workers_str, NULL, 10); - if (num_workers < 0) { - RAY_LOG(FATAL) << "Number of workers must be nonnegative"; - } - } - - char redis_primary_addr[16]; - char *redis_addr = NULL; - int redis_port = -1; - if (!redis_primary_addr_port) { - /* Start the local scheduler without connecting to Redis. In this case, all - * submitted tasks will be queued and scheduled locally. */ - if (plasma_manager_socket_name) { - RAY_LOG(FATAL) << "if a plasma manager socket name is provided with the " - << "-m switch, then a redis address must be provided with " - << "the -r switch"; - } - } else { - int redis_primary_port; - /* Parse the primary Redis address into an IP address and a port. */ - if (parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr, - &redis_primary_port) == -1) { - RAY_LOG(FATAL) << "if a redis address is provided with the -r switch, it " - << "should be formatted like 127.0.0.1:6379"; - } - if (!plasma_manager_socket_name) { - RAY_LOG(FATAL) << "please specify socket for connecting to Plasma " - << "manager with -m switch"; - } - redis_addr = redis_primary_addr; - redis_port = redis_primary_port; - } - - start_server(node_ip_address, scheduler_socket_name, redis_addr, redis_port, - plasma_store_socket_name, plasma_manager_socket_name, - plasma_manager_address, global_scheduler_exists, - static_resource_conf, start_worker_command, num_workers); -} -#endif diff --git a/src/local_scheduler/local_scheduler.h b/src/local_scheduler/local_scheduler.h deleted file mode 100644 index 39c7523fe7ed4..0000000000000 --- a/src/local_scheduler/local_scheduler.h +++ /dev/null @@ -1,176 +0,0 @@ -#ifndef LOCAL_SCHEDULER_H -#define LOCAL_SCHEDULER_H - -#include "event_loop.h" -#include "local_scheduler_shared.h" -#include "task.h" - -/** - * Establish a connection to a new client. - * - * @param loop Event loop of the local scheduler. - * @param listener_socket Socket the local scheduler is listening on for new - * client requests. - * @param context State of the local scheduler. - * @param events Flag for events that are available on the listener socket. - * @return Void. - */ -void new_client_connection(event_loop *loop, - int listener_sock, - void *context, - int events); - -/** - * Check if a driver is still alive. - * - * @param driver_id The ID of the driver. - * @return True if the driver is still alive and false otherwise. - */ -bool is_driver_alive(WorkerID driver_id); - -/** - * This function can be called by the scheduling algorithm to assign a task - * to a worker. - * - * @param info - * @param task The task that is submitted to the worker. - * @param worker The worker to assign the task to. - * @return Void. - */ -void assign_task_to_worker(LocalSchedulerState *state, - TaskExecutionSpec &task, - LocalSchedulerClient *worker); - -/* - * This function is called whenever a task has finished on one of the workers. - * It updates the resource accounting and the global state store. - * - * @param state The local scheduler state. 
- * @param worker The worker that finished the task. - * @return Void. - */ -void finish_task(LocalSchedulerState *state, LocalSchedulerClient *worker); - -/** - * This is the callback that is used to process a notification from the Plasma - * store that an object has been sealed. - * - * @param loop The local scheduler's event loop. - * @param client_sock The file descriptor to read the notification from. - * @param context The local scheduler state. - * @param events - * @return Void. - */ -void process_plasma_notification(event_loop *loop, - int client_sock, - void *context, - int events); - -/** - * Reconstruct an object. If the object does not exist on any nodes, according - * to the state tables, and if the object is not already being reconstructed, - * this triggers a single reexecution of the task that originally created the - * object. - * - * @param state The local scheduler state. - * @param object_id The ID of the object to reconstruct. - * @return Void. - */ -void reconstruct_object(LocalSchedulerState *state, ObjectID object_id); - -void print_resource_info(const LocalSchedulerState *s, const TaskSpec *spec); - -/** - * Kill a worker, if it is a child process, and clean up all of its associated - * state. Note that this function is also called on drivers, but it should not - * actually send a kill signal to drivers. - * - * @param state The local scheduler state. - * @param worker The local scheduler client to kill. - * @param wait A boolean representing whether to wait for the killed worker to - * exit. - * @param suppress_warning A bool that is true if we should not warn the driver, - * and false otherwise. This should only be true when a driver is - * removed. - * @return Void. - */ -void kill_worker(LocalSchedulerState *state, - LocalSchedulerClient *worker, - bool wait, - bool suppress_warning); - -/** - * Start a worker. This forks a new worker process that can be added to the - * pool of available workers, pending registration of its PID with the local - * scheduler. - * - * @param state The local scheduler state. - * @param Void. - */ -void start_worker(LocalSchedulerState *state); - -/** - * Check if a certain quantity of dynamic resources are available. If num_cpus - * is 0, we ignore the dynamic number of available CPUs (which may be negative). - * - * @param state The state of the local scheduler. - * @param resources The resources to check. - * @return True if there are enough CPUs and GPUs and false otherwise. - */ -bool check_dynamic_resources( - LocalSchedulerState *state, - const std::unordered_map &resources); - -/** - * Acquire additional resources (CPUs and GPUs) for a worker. - * - * @param state The local scheduler state. - * @param worker The worker who is acquiring resources. - * @param resources The resources to acquire. - * @return Void. - */ -void acquire_resources( - LocalSchedulerState *state, - LocalSchedulerClient *worker, - const std::unordered_map &resources); - -/** - * Return resources (CPUs and GPUs) being used by a worker to the local - * scheduler. - * - * @param state The local scheduler state. - * @param worker The worker who is returning resources. - * @param resources The resources to release. - * @return Void. - */ -void release_resources( - LocalSchedulerState *state, - LocalSchedulerClient *worker, - const std::unordered_map &resources); - -/** The following methods are for testing purposes only. 
*/ -#ifdef LOCAL_SCHEDULER_TEST -LocalSchedulerState *LocalSchedulerState_init( - const char *node_ip_address, - event_loop *loop, - const char *redis_addr, - int redis_port, - const char *local_scheduler_socket_name, - const char *plasma_manager_socket_name, - const char *plasma_store_socket_name, - const char *plasma_manager_address, - bool global_scheduler_exists, - const std::unordered_map &static_resource_vector, - const char *worker_path, - int num_workers); - -SchedulingAlgorithmState *get_algorithm_state(LocalSchedulerState *state); - -void process_message(event_loop *loop, - int client_sock, - void *context, - int events); - -#endif - -#endif /* LOCAL_SCHEDULER_H */ diff --git a/src/local_scheduler/local_scheduler_algorithm.cc b/src/local_scheduler/local_scheduler_algorithm.cc deleted file mode 100644 index 89d6c8d6df56c..0000000000000 --- a/src/local_scheduler/local_scheduler_algorithm.cc +++ /dev/null @@ -1,1851 +0,0 @@ -#include "local_scheduler_algorithm.h" - -#include -#include -#include - -#include "state/task_table.h" -#include "state/actor_notification_table.h" -#include "state/db_client_table.h" -#include "state/error_table.h" -#include "state/local_scheduler_table.h" -#include "state/object_table.h" -#include "local_scheduler_shared.h" -#include "local_scheduler.h" -#include "common/task.h" - -/* Declared for convenience. */ -void remove_actor(SchedulingAlgorithmState *algorithm_state, ActorID actor_id); - -void give_task_to_global_scheduler(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec); - -void give_task_to_local_scheduler(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec, - DBClientID local_scheduler_id); - -void clear_missing_dependencies(SchedulingAlgorithmState *algorithm_state, - std::list::iterator it); - -/** A data structure used to track which objects are available locally and - * which objects are being actively fetched. Objects of this type are used for - * both the scheduling algorithm state's local_objects and remote_objects - * tables. An ObjectEntry should be in at most one of the tables and not both - * simultaneously. */ -struct ObjectEntry { - /** A vector of tasks dependent on this object. These tasks are a subset of - * the tasks in the waiting queue. Each element actually stores a reference - * to the corresponding task's queue entry in waiting queue, for fast - * deletion when all of the task's dependencies become available. */ - std::vector::iterator> dependent_tasks; - /** Whether or not to request a transfer of this object. This should be set - * to true for all objects except for actor dummy objects, where the object - * must be generated by executing the task locally. */ - bool request_transfer; -}; - -/** This struct contains information about a specific actor. This struct will be - * used inside of a hash table. */ -typedef struct { - /** The number of tasks that have been executed on this actor so far, per - * handle. This is used to guarantee execution of tasks on actors in the - * order that the tasks were submitted, per handle. Tasks from different - * handles to the same actor may be interleaved. */ - std::unordered_map task_counters; - /** These are the execution dependencies that make up the frontier of the - * actor's runnable tasks. For each actor handle, we store the object ID - * that represents the execution dependency for the next runnable task - * submitted by that handle. 
*/ - std::unordered_map frontier_dependencies; - /** The return value of the most recently executed task. The next task to - * execute should take this as an execution dependency at dispatch time. Set - * to nil if there are no execution dependencies (e.g., this is the first - * task to execute). */ - ObjectID execution_dependency; - /** A queue of tasks to be executed on this actor. The tasks will be sorted by - * the order of their actor counters. */ - std::list *task_queue; - /** The worker that the actor is running on. */ - LocalSchedulerClient *worker; - /** True if the worker is available and false otherwise. */ - bool worker_available; -} LocalActorInfo; - -/** Part of the local scheduler state that is maintained by the scheduling - * algorithm. */ -struct SchedulingAlgorithmState { - /** An array of pointers to tasks that are waiting for dependencies. */ - std::list *waiting_task_queue; - /** An array of pointers to tasks whose dependencies are ready but that are - * waiting to be assigned to a worker. */ - std::list *dispatch_task_queue; - /** This is a hash table from actor ID to information about that actor. In - * particular, a queue of tasks that are waiting to execute on that actor. - * This is only used for actors that exist locally. */ - std::unordered_map local_actor_infos; - /** This is a set of the IDs of the actors that have tasks waiting to run. - * The purpose is to make it easier to dispatch tasks without looping over - * all of the actors. Note that this is an optimization and is not strictly - * necessary. */ - std::unordered_set actors_with_pending_tasks; - /** A vector of actor tasks that have been submitted but this local scheduler - * doesn't know which local scheduler is responsible for them, so cannot - * assign them to the correct local scheduler yet. Whenever a notification - * about a new local scheduler arrives, we will resubmit all of these tasks - * locally. */ - std::vector cached_submitted_actor_tasks; - /** An array of pointers to workers in the worker pool. These are workers - * that have registered a PID with us and that are now waiting to be - * assigned a task to execute. */ - std::vector available_workers; - /** An array of pointers to workers that are currently executing a task, - * unblocked. These are the workers that are leasing some number of - * resources. */ - std::vector executing_workers; - /** An array of pointers to workers that are currently executing a task, - * blocked on some object(s) that isn't available locally yet. These are the - * workers that are executing a task, but that have temporarily returned the - * task's required resources. */ - std::vector blocked_workers; - /** A hash map of the objects that are available in the local Plasma store. - * The key is the object ID. This information could be a little stale. */ - std::unordered_map local_objects; - /** A hash map of the objects that are not available locally. These are - * currently being fetched by this local scheduler. The key is the object - * ID. Every local_scheduler_fetch_timeout_milliseconds, a Plasma fetch - * request will be sent the object IDs in this table. Each entry also holds - * an array of queued tasks that are dependent on it. */ - std::unordered_map remote_objects; -}; - -SchedulingAlgorithmState *SchedulingAlgorithmState_init(void) { - SchedulingAlgorithmState *algorithm_state = new SchedulingAlgorithmState(); - /* Initialize the local data structures used for queuing tasks and workers. 
*/ - algorithm_state->waiting_task_queue = new std::list(); - algorithm_state->dispatch_task_queue = new std::list(); - - return algorithm_state; -} - -void SchedulingAlgorithmState_free(SchedulingAlgorithmState *algorithm_state) { - /* Free all of the tasks in the waiting queue. */ - delete algorithm_state->waiting_task_queue; - /* Free all the tasks in the dispatch queue. */ - delete algorithm_state->dispatch_task_queue; - /* Remove all of the remaining actors. */ - while (algorithm_state->local_actor_infos.size() != 0) { - auto it = algorithm_state->local_actor_infos.begin(); - ActorID actor_id = it->first; - remove_actor(algorithm_state, actor_id); - } - /* Free the algorithm state. */ - delete algorithm_state; -} - -/** - * This is a helper method to check if a worker is in a vector of workers. - * - * @param worker_vector A vector of workers. - * @param The worker to look for in the vector. - * @return True if the worker is in the vector and false otherwise. - */ -bool worker_in_vector(std::vector &worker_vector, - LocalSchedulerClient *worker) { - auto it = std::find(worker_vector.begin(), worker_vector.end(), worker); - return it != worker_vector.end(); -} - -/** - * This is a helper method to remove a worker from a vector of workers if it is - * present in the vector. - * - * @param worker_vector A vector of workers. - * @param The worker to remove. - * @return True if the worker was removed and false otherwise. - */ -bool remove_worker_from_vector( - std::vector &worker_vector, - LocalSchedulerClient *worker) { - /* Find the worker in the list of executing workers. */ - auto it = std::find(worker_vector.begin(), worker_vector.end(), worker); - bool remove_worker = (it != worker_vector.end()); - if (remove_worker) { - /* Remove the worker from the list of workers. */ - using std::swap; - swap(*it, worker_vector.back()); - worker_vector.pop_back(); - } - return remove_worker; -} - -void provide_scheduler_info(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerInfo *info) { - info->total_num_workers = state->workers.size(); - /* TODO(swang): Provide separate counts for tasks that are waiting for - * dependencies vs tasks that are waiting to be assigned. */ - int64_t waiting_task_queue_length = - algorithm_state->waiting_task_queue->size(); - int64_t dispatch_task_queue_length = - algorithm_state->dispatch_task_queue->size(); - info->task_queue_length = - waiting_task_queue_length + dispatch_task_queue_length; - info->available_workers = algorithm_state->available_workers.size(); - /* Copy static and dynamic resource information. */ - info->dynamic_resources = state->dynamic_resources; - info->static_resources = state->static_resources; -} - -/** - * Create the LocalActorInfo struct for an actor worker that this local - * scheduler is responsible for. For a given actor, this will either be done - * when the first task for that actor arrives or when the worker running that - * actor connects to the local scheduler. - * - * @param algorithm_state The state of the scheduling algorithm. - * @param actor_id The actor ID of the actor being created. - * @param initial_execution_dependency The dummy object ID of the actor - * creation task. - * @param worker The worker struct for the worker that is running this actor. 
- * If the worker struct has not been created yet (meaning that the worker - * that is running this actor has not registered with the local scheduler - * yet, and so create_actor is being called because a task for that actor - * has arrived), then this should be NULL. - * @return Void. - */ -void create_actor(SchedulingAlgorithmState *algorithm_state, - const ActorID &actor_id, - const ObjectID &initial_execution_dependency, - LocalSchedulerClient *worker) { - LocalActorInfo entry; - entry.task_counters[ActorHandleID::nil()] = 0; - entry.frontier_dependencies[ActorHandleID::nil()] = ObjectID::nil(); - /* The actor has not yet executed any tasks, so there are no execution - * dependencies for the next task to be scheduled. */ - entry.execution_dependency = initial_execution_dependency; - entry.task_queue = new std::list(); - entry.worker = worker; - entry.worker_available = false; - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) == 0); - algorithm_state->local_actor_infos[actor_id] = entry; - - /* Log some useful information about the actor that we created. */ - RAY_LOG(DEBUG) << "Creating actor with ID " << actor_id; -} - -void remove_actor(SchedulingAlgorithmState *algorithm_state, ActorID actor_id) { - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) == 1); - LocalActorInfo &entry = - algorithm_state->local_actor_infos.find(actor_id)->second; - - /* Log some useful information about the actor that we're removing. */ - size_t count = entry.task_queue->size(); - if (count > 0) { - RAY_LOG(WARNING) << "Removing actor with ID " << actor_id << " and " - << count << " remaining tasks."; - } - - entry.task_queue->clear(); - delete entry.task_queue; - /* Remove the entry from the hash table. */ - algorithm_state->local_actor_infos.erase(actor_id); - - /* Remove the actor ID from the set of actors with pending tasks. */ - algorithm_state->actors_with_pending_tasks.erase(actor_id); -} - -/** - * Dispatch a task to an actor if possible. - * - * @param state The state of the local scheduler. - * @param algorithm_state The state of the scheduling algorithm. - * @param actor_id The ID of the actor corresponding to the worker. - * @return True if a task was dispatched to the actor and false otherwise. - */ -bool dispatch_actor_task(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id) { - /* Make sure this worker actually is an actor. */ - RAY_CHECK(!actor_id.is_nil()); - /* Return if this actor doesn't have any pending tasks. */ - if (algorithm_state->actors_with_pending_tasks.find(actor_id) == - algorithm_state->actors_with_pending_tasks.end()) { - return false; - } - /* Make sure this actor belongs to this local scheduler. */ - if (state->actor_mapping.count(actor_id) != 1) { - /* The creation notification for this actor has not yet arrived at the local - * scheduler. This should be rare. */ - return false; - } - RAY_CHECK(state->actor_mapping[actor_id].local_scheduler_id == - get_db_client_id(state->db)); - - /* Get the local actor entry for this actor. */ - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0); - LocalActorInfo &entry = - algorithm_state->local_actor_infos.find(actor_id)->second; - - /* There should be some queued tasks for this actor. */ - RAY_CHECK(!entry.task_queue->empty()); - /* If the worker is not available, we cannot assign a task to it. */ - if (!entry.worker_available) { - return false; - } - - /* Check whether we can execute the first task in the queue. 
*/ - auto task = entry.task_queue->begin(); - TaskSpec *spec = task->Spec(); - ActorHandleID next_task_handle_id = TaskSpec_actor_handle_id(spec); - /* We can only execute tasks in order of task_counter. */ - if (TaskSpec_actor_counter(spec) != - entry.task_counters[next_task_handle_id]) { - return false; - } - - /* If there are not enough resources available, we cannot assign the task. */ - RAY_CHECK(0 == TaskSpec_get_required_resource(spec, "GPU")); - if (!check_dynamic_resources(state, TaskSpec_get_required_resources(spec))) { - return false; - } - - /* Update the task's execution dependencies to reflect the actual execution - * order to support deterministic reconstruction. */ - /* NOTE(swang): The update of an actor task's execution dependencies is - * performed asynchronously. This means that if this local scheduler dies, we - * may lose updates that are in flight to the task table. We only guarantee - * deterministic reconstruction ordering for tasks whose updates are - * reflected in the task table. */ - std::vector ordered_execution_dependencies; - ordered_execution_dependencies.push_back(entry.execution_dependency); - task->SetExecutionDependencies(ordered_execution_dependencies); - - /* Assign the first task in the task queue to the worker and mark the worker - * as unavailable. */ - assign_task_to_worker(state, *task, entry.worker); - entry.execution_dependency = TaskSpec_actor_dummy_object(spec); - entry.worker_available = false; - /* Extend the frontier to include the assigned task. */ - entry.task_counters[next_task_handle_id] += 1; - entry.frontier_dependencies[next_task_handle_id] = entry.execution_dependency; - - /* Remove the task from the actor's task queue. */ - entry.task_queue->erase(task); - /* If there are no more tasks in the queue, then indicate that the actor has - * no tasks. */ - if (entry.task_queue->empty()) { - algorithm_state->actors_with_pending_tasks.erase(actor_id); - } - - return true; -} - -void handle_convert_worker_to_actor( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - const ActorID &actor_id, - const ObjectID &initial_execution_dependency, - LocalSchedulerClient *worker) { - if (algorithm_state->local_actor_infos.count(actor_id) == 0) { - create_actor(algorithm_state, actor_id, initial_execution_dependency, - worker); - } else { - /* In this case, the LocalActorInfo struct was already been created by the - * first call to add_task_to_actor_queue. However, the worker field was not - * filled out, so fill out the correct worker field now. */ - algorithm_state->local_actor_infos[actor_id].worker = worker; - } - /* Increment the task counter for the creator's handle to account for the - * actor creation task. */ - auto &task_counters = - algorithm_state->local_actor_infos[actor_id].task_counters; - RAY_CHECK(task_counters[ActorHandleID::nil()] == 0); - task_counters[ActorHandleID::nil()]++; -} - -/** - * Finishes a killed task by inserting dummy objects for each of its returns. - */ -void finish_killed_task(LocalSchedulerState *state, - TaskExecutionSpec &execution_spec) { - TaskSpec *spec = execution_spec.Spec(); - int64_t num_returns = TaskSpec_num_returns(spec); - for (int i = 0; i < num_returns; i++) { - ObjectID object_id = TaskSpec_return(spec, i); - std::shared_ptr data; - // TODO(ekl): this writes an invalid arrow object, which is sufficient to - // signal that the worker failed, but it would be nice to return more - // detailed failure metadata in the future. 
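The counter check in dispatch_actor_task above is what enforces per-handle FIFO execution on an actor: a queued task is dispatched only when its counter equals the number of tasks already executed for the handle that submitted it, and that counter (the handle's frontier) is advanced as the task is assigned. A reduced sketch of the ordering rule follows, with plain ints standing in for ActorHandleID and a trivial queue in place of the real sorted task queue.

    #include <cstdint>
    #include <iostream>
    #include <list>
    #include <unordered_map>

    // Simplified actor task: which handle submitted it and its per-handle counter.
    struct ActorTask {
      int handle_id;
      int64_t counter;
    };

    struct ActorQueue {
      std::list<ActorTask> task_queue;                 // Pending tasks for this actor.
      std::unordered_map<int, int64_t> task_counters;  // Tasks already dispatched, per handle.
    };

    // Dispatch the head of the queue only if it is the next expected task for its handle.
    bool DispatchNext(ActorQueue *actor) {
      if (actor->task_queue.empty()) return false;
      const ActorTask &head = actor->task_queue.front();
      if (head.counter != actor->task_counters[head.handle_id]) {
        return false;  // A predecessor from the same handle has not run yet.
      }
      actor->task_counters[head.handle_id] += 1;  // Extend the handle's frontier.
      actor->task_queue.pop_front();
      return true;
    }

    int main() {
      ActorQueue actor;
      // The task with counter 1 from handle 0 is queued, but task 0 has not arrived yet.
      actor.task_queue.push_back({0, 1});
      std::cout << DispatchNext(&actor) << "\n";  // 0: handle 0 still owes task 0.
      actor.task_queue.push_front({0, 0});        // The missing predecessor arrives.
      std::cout << DispatchNext(&actor) << "\n";  // 1: counter 0 runs first...
      std::cout << DispatchNext(&actor) << "\n";  // 1: ...then counter 1.
      return 0;
    }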
- arrow::Status status = - state->plasma_conn->Create(object_id.to_plasma_id(), 1, NULL, 0, &data); - if (!status.IsPlasmaObjectExists()) { - ARROW_CHECK_OK(status); - ARROW_CHECK_OK(state->plasma_conn->Seal(object_id.to_plasma_id())); - } - } - /* Mark the task as done. */ - if (state->db != NULL) { - Task *task = Task_alloc(execution_spec, TaskStatus::DONE, - get_db_client_id(state->db)); - // In most cases, task_table_update would be appropriate, however, it is - // possible in some cases that the task has not yet been added to the task - // table (e.g., if it is an actor task that is queued locally because the - // actor has not been created yet). - task_table_add_task(state->db, task, NULL, NULL, NULL); - } -} - -/** - * Insert a task queue entry into an actor's dispatch queue. The task is - * inserted in sorted order by task counter. If this is the first task - * scheduled to this actor and the worker process has not yet connected, then - * this also creates a LocalActorInfo entry for the actor. - * - * @param state The state of the local scheduler. - * @param algorithm_state The state of the scheduling algorithm. - * @param task_entry The task queue entry to add to the actor's queue. - * @return Void. - */ -void insert_actor_task_queue(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec task_entry) { - TaskSpec *spec = task_entry.Spec(); - /* Get the local actor entry for this actor. */ - ActorID actor_id = TaskSpec_actor_id(spec); - ActorHandleID task_handle_id = TaskSpec_actor_handle_id(spec); - int64_t task_counter = TaskSpec_actor_counter(spec); - - /* Fail the task immediately; it's destined for a dead actor. */ - if (state->removed_actors.find(actor_id) != state->removed_actors.end()) { - finish_killed_task(state, task_entry); - return; - } - - LocalActorInfo &entry = - algorithm_state->local_actor_infos.find(actor_id)->second; - if (entry.task_counters.count(task_handle_id) == 0) { - entry.task_counters[task_handle_id] = 0; - } - /* Extend the frontier to include the new handle. */ - if (entry.frontier_dependencies.count(task_handle_id) == 0) { - RAY_CHECK(task_entry.ExecutionDependencies().size() == 1); - entry.frontier_dependencies[task_handle_id] = - task_entry.ExecutionDependencies()[0]; - } - - /* As a sanity check, the counter of the new task should be greater than the - * number of tasks that have executed on this actor so far (since we are - * guaranteeing in-order execution of the tasks on the actor). TODO(rkn): This - * check will fail if the fault-tolerance mechanism resubmits a task on an - * actor. */ - if (task_counter < entry.task_counters[task_handle_id]) { - RAY_LOG(INFO) << "A task that has already been executed has been " - << "resubmitted, so we are ignoring it. This should only " - << "happen during reconstruction."; - return; - } - - /* Insert the task spec to the actor's task queue in sorted order, per actor - * handle ID. Find the first task in the queue with a counter greater than - * the submitted task's and the same handle ID. */ - auto it = entry.task_queue->begin(); - for (; it != entry.task_queue->end(); it++) { - TaskSpec *pending_task_spec = it->Spec(); - /* Skip tasks submitted by a different handle. */ - if (!(task_handle_id == TaskSpec_actor_handle_id(pending_task_spec))) { - continue; - } - /* A duplicate task submitted by the same handle. */ - if (task_counter == TaskSpec_actor_counter(pending_task_spec)) { - RAY_LOG(INFO) << "A task was resubmitted, so we are ignoring it. 
This " - << "should only happen during reconstruction."; - return; - } - /* We found a task with the same handle ID and a greater task counter. */ - if (task_counter < TaskSpec_actor_counter(pending_task_spec)) { - break; - } - } - entry.task_queue->insert(it, std::move(task_entry)); - - /* Record the fact that this actor has a task waiting to execute. */ - algorithm_state->actors_with_pending_tasks.insert(actor_id); -} - -/** - * Queue a task to be dispatched for an actor. Update the task table for the - * queued task. TODO(rkn): Should we also update the task table in the case - * where the tasks are cached locally? - * - * @param state The state of the local scheduler. - * @param algorithm_state The state of the scheduling algorithm. - * @param spec The task spec to add. - * @param from_global_scheduler True if the task was assigned to this local - * scheduler by the global scheduler and false if it was submitted - * locally by a worker. - * @return Void. - */ -void queue_actor_task(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec, - bool from_global_scheduler) { - TaskSpec *spec = execution_spec.Spec(); - ActorID actor_id = TaskSpec_actor_id(spec); - RAY_CHECK(!actor_id.is_nil()); - - /* Update the task table. */ - if (state->db != NULL) { - Task *task = Task_alloc(execution_spec, TaskStatus::QUEUED, - get_db_client_id(state->db)); - if (from_global_scheduler) { - /* If the task is from the global scheduler, it's already been added to - * the task table, so just update the entry. */ - task_table_update(state->db, task, NULL, NULL, NULL); - } else { - /* Otherwise, this is the first time the task has been seen in the - * system (unless it's a resubmission of a previous task), so add the - * entry. */ - task_table_add_task(state->db, task, NULL, NULL, NULL); - } - } - - // Create a new task queue entry. This must come after the above block because - // insert_actor_task_queue may call task_table_update internally, which must - // come after the prior call to task_table_add_task. - TaskExecutionSpec copy = TaskExecutionSpec(&execution_spec); - insert_actor_task_queue(state, algorithm_state, std::move(copy)); -} - -/** - * Fetch a queued task's missing object dependency. The fetch request will be - * retried every local_scheduler_fetch_timeout_milliseconds until the object is - * available locally. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param task_entry_it A reference to the task entry in the waiting queue. - * @param obj_id The ID of the object that the task is dependent on. - * @param request_transfer Whether to request a transfer of this object from - * other plasma managers. This should be set to false for execution - * dependencies, which should be fulfilled by executing the - * corresponding task locally. - * @returns Void. - */ -void fetch_missing_dependency( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - std::list::iterator task_entry_it, - plasma::ObjectID obj_id, - bool request_transfer) { - if (algorithm_state->remote_objects.count(obj_id) == 0) { - /* We weren't actively fetching this object. Try the fetch once - * immediately. */ - if (state->plasma_conn->get_manager_fd() != -1) { - auto arrow_status = state->plasma_conn->Fetch(1, &obj_id); - if (!arrow_status.ok()) { - LocalSchedulerState_free(state); - /* TODO(swang): Local scheduler should also exit even if there are no - * pending fetches. 
This could be done by subscribing to the db_client - * table, or pinging the plasma manager in the heartbeat handler. */ - RAY_LOG(FATAL) << "Lost connection to the plasma manager, local " - << "scheduler is exiting. Error: " - << arrow_status.ToString(); - } - } - /* Create an entry and add it to the list of active fetch requests to - * ensure that the fetch actually happens. The entry will be moved to the - * hash table of locally available objects in handle_object_available when - * the object becomes available locally. It will get freed if the object is - * subsequently removed locally. */ - ObjectEntry entry; - entry.request_transfer = request_transfer; - algorithm_state->remote_objects[obj_id] = entry; - } - algorithm_state->remote_objects[obj_id].dependent_tasks.push_back( - task_entry_it); -} - -/** - * Fetch a queued task's missing object dependencies. The fetch requests will - * be retried every local_scheduler_fetch_timeout_milliseconds until all - * objects are available locally. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param task_entry_it A reference to the task entry in the waiting queue. - * @returns Void. - */ -void fetch_missing_dependencies( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - std::list::iterator task_entry_it) { - int64_t num_dependencies = task_entry_it->NumDependencies(); - int num_missing_dependencies = 0; - for (int64_t i = 0; i < num_dependencies; ++i) { - int count = task_entry_it->DependencyIdCount(i); - for (int j = 0; j < count; ++j) { - ObjectID obj_id = task_entry_it->DependencyId(i, j); - /* If the entry is not yet available locally, record the dependency. */ - if (algorithm_state->local_objects.count(obj_id) == 0) { - /* Do not request a transfer from other plasma managers if this is an - * execution dependency. */ - bool request_transfer = task_entry_it->IsStaticDependency(i); - fetch_missing_dependency(state, algorithm_state, task_entry_it, - obj_id.to_plasma_id(), request_transfer); - ++num_missing_dependencies; - } - } - } - RAY_CHECK(num_missing_dependencies > 0); -} - -/** - * Clear a queued task's missing object dependencies. This is the inverse of - * fetch_missing_dependencies. - * TODO(swang): Test this function. - * - * @param algorithm_state The scheduling algorithm state. - * @param task_entry_it A reference to the task entry in the waiting queue. - * @returns Void. - */ -void clear_missing_dependencies( - SchedulingAlgorithmState *algorithm_state, - std::list::iterator task_entry_it) { - int64_t num_dependencies = task_entry_it->NumDependencies(); - for (int64_t i = 0; i < num_dependencies; ++i) { - int count = task_entry_it->DependencyIdCount(i); - for (int j = 0; j < count; ++j) { - ObjectID obj_id = task_entry_it->DependencyId(i, j); - /* If this object dependency is missing, remove this task from the - * object's list of dependent tasks. */ - auto entry = algorithm_state->remote_objects.find(obj_id); - if (entry != algorithm_state->remote_objects.end()) { - /* Find and remove the given task. */ - auto &dependent_tasks = entry->second.dependent_tasks; - for (auto dependent_task_it = dependent_tasks.begin(); - dependent_task_it != dependent_tasks.end();) { - if (*dependent_task_it == task_entry_it) { - dependent_task_it = dependent_tasks.erase(dependent_task_it); - } else { - dependent_task_it++; - } - } - /* If the missing object dependency has no more dependent tasks, then - * remove it. 
*/ - if (dependent_tasks.empty()) { - algorithm_state->remote_objects.erase(entry); - } - } - } - } -} - -/** - * Check if all of the remote object arguments for a task are available in the - * local object store. - * - * @param algorithm_state The scheduling algorithm state. - * @param task Task specification of the task to check. - * @return bool This returns true if all of the remote object arguments for the - * task are present in the local object store, otherwise it returns - * false. - */ -bool can_run(SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &task) { - int64_t num_dependencies = task.NumDependencies(); - for (int i = 0; i < num_dependencies; ++i) { - int count = task.DependencyIdCount(i); - for (int j = 0; j < count; ++j) { - ObjectID obj_id = task.DependencyId(i, j); - if (algorithm_state->local_objects.count(obj_id) == 0) { - /* The object is not present locally, so this task cannot be scheduled - * right now. */ - return false; - } - } - } - return true; -} - -bool object_locally_available(SchedulingAlgorithmState *algorithm_state, - ObjectID object_id) { - return algorithm_state->local_objects.count(object_id) == 1; -} - -/* TODO(swang): This method is not covered by any valgrind tests. */ -int fetch_object_timeout_handler(event_loop *loop, timer_id id, void *context) { - int64_t start_time = current_time_ms(); - - LocalSchedulerState *state = (LocalSchedulerState *) context; - /* Only try the fetches if we are connected to the object store manager. */ - if (state->plasma_conn->get_manager_fd() == -1) { - RAY_LOG(INFO) - << "Local scheduler is not connected to a object store manager"; - return RayConfig::instance().local_scheduler_fetch_timeout_milliseconds(); - } - - std::vector object_id_vec; - for (auto const &entry : state->algorithm_state->remote_objects) { - if (entry.second.request_transfer) { - object_id_vec.push_back(entry.first); - } - } - - ObjectID *object_ids = object_id_vec.data(); - int64_t num_object_ids = object_id_vec.size(); - - /* Divide very large fetch requests into smaller fetch requests so that a - * single fetch request doesn't block the plasma manager for a long time. */ - for (int64_t j = 0; j < num_object_ids; - j += RayConfig::instance().local_scheduler_fetch_request_size()) { - int num_objects_in_request = - std::min( - num_object_ids, - j + RayConfig::instance().local_scheduler_fetch_request_size()) - - j; - auto arrow_status = state->plasma_conn->Fetch( - num_objects_in_request, - reinterpret_cast(&object_ids[j])); - if (!arrow_status.ok()) { - LocalSchedulerState_free(state); - RAY_LOG(FATAL) << "Lost connection to the plasma manager, local " - << "scheduler is exiting. Error: " - << arrow_status.ToString(); - } - } - - /* Print a warning if this method took too long. */ - int64_t end_time = current_time_ms(); - if (end_time - start_time > - RayConfig::instance().max_time_for_handler_milliseconds()) { - RAY_LOG(WARNING) << "fetch_object_timeout_handler took " - << end_time - start_time << " milliseconds."; - } - - /* Wait at least local_scheduler_fetch_timeout_milliseconds before running - * this timeout handler again. But if we're waiting for a large number of - * objects, wait longer (e.g., 10 seconds for one million objects) so that we - * don't overwhelm the plasma manager. */ - return std::max( - RayConfig::instance().local_scheduler_fetch_timeout_milliseconds(), - int64_t(0.01 * num_object_ids)); -} - -/* TODO(swang): This method is not covered by any valgrind tests. 
*/ -int reconstruct_object_timeout_handler(event_loop *loop, - timer_id id, - void *context) { - int64_t start_time = current_time_ms(); - - LocalSchedulerState *state = (LocalSchedulerState *) context; - - /* This vector is used to track which object IDs to reconstruct next. If the - * vector is empty, we repopulate it with all of the keys of the remote object - * table. During every pass through this handler, we call reconstruct on up to - * max_num_to_reconstruct elements of the vector (after first checking that - * the object IDs are still missing). */ - static std::vector object_ids_to_reconstruct; - - /* If the set is empty, repopulate it. */ - if (object_ids_to_reconstruct.size() == 0) { - for (auto const &entry : state->algorithm_state->remote_objects) { - object_ids_to_reconstruct.push_back(entry.first); - } - } - - int64_t num_reconstructed = 0; - for (size_t i = 0; i < object_ids_to_reconstruct.size(); i++) { - ObjectID object_id = object_ids_to_reconstruct[i]; - /* Only call reconstruct if we are still missing the object. */ - if (state->algorithm_state->remote_objects.find(object_id) != - state->algorithm_state->remote_objects.end()) { - reconstruct_object(state, object_id); - } - num_reconstructed++; - if (num_reconstructed == RayConfig::instance().max_num_to_reconstruct()) { - break; - } - } - object_ids_to_reconstruct.erase( - object_ids_to_reconstruct.begin(), - object_ids_to_reconstruct.begin() + num_reconstructed); - - /* Print a warning if this method took too long. */ - int64_t end_time = current_time_ms(); - if (end_time - start_time > - RayConfig::instance().max_time_for_handler_milliseconds()) { - RAY_LOG(WARNING) << "reconstruct_object_timeout_handler took " - << end_time - start_time << " milliseconds."; - } - - return RayConfig::instance() - .local_scheduler_reconstruction_timeout_milliseconds(); -} - -int rerun_actor_creation_tasks_timeout_handler(event_loop *loop, - timer_id id, - void *context) { - int64_t start_time = current_time_ms(); - - LocalSchedulerState *state = (LocalSchedulerState *) context; - - // Create a set of the dummy object IDs for the actor creation tasks to - // reconstruct. - std::unordered_set actor_dummy_objects; - for (auto const &execution_spec : - state->algorithm_state->cached_submitted_actor_tasks) { - ObjectID actor_creation_dummy_object_id = - TaskSpec_actor_creation_dummy_object_id(execution_spec.Spec()); - actor_dummy_objects.insert(actor_creation_dummy_object_id); - } - - // Issue reconstruct calls. - for (auto const &object_id : actor_dummy_objects) { - reconstruct_object(state, object_id); - } - - // Print a warning if this method took too long. - int64_t end_time = current_time_ms(); - if (end_time - start_time > - RayConfig::instance().max_time_for_handler_milliseconds()) { - RAY_LOG(WARNING) << "reconstruct_object_timeout_handler took " - << end_time - start_time << " milliseconds."; - } - - return RayConfig::instance() - .local_scheduler_reconstruction_timeout_milliseconds(); -} - -/** - * Return true if there are still some resources available and false otherwise. - * - * @param state The scheduler state. - * @return True if there are still some resources and false if there are not. 
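The fetch timeout handler above splits a large set of missing object IDs into bounded fetch requests so that no single call occupies the plasma manager for too long. Below is a minimal sketch of that batching pattern; the Fetch() stub and the fixed batch size are assumptions standing in for the real manager call and for local_scheduler_fetch_request_size.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

constexpr std::size_t kFetchRequestSize = 4;  // stand-in for the config value

// Stand-in for a plasma-manager fetch call.
void Fetch(const std::string *ids, std::size_t count) {
  std::cout << "fetch request for " << count << " objects\n";
}

// Divide a large fetch into smaller requests of at most kFetchRequestSize IDs.
void FetchInBatches(const std::vector<std::string> &missing) {
  for (std::size_t i = 0; i < missing.size(); i += kFetchRequestSize) {
    std::size_t batch = std::min(kFetchRequestSize, missing.size() - i);
    Fetch(missing.data() + i, batch);
  }
}

int main() {
  std::vector<std::string> missing(10, "obj");
  FetchInBatches(missing);  // issues requests of size 4, 4, 2
}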
- */ -bool resources_available(LocalSchedulerState *state) { - bool resources_available = false; - for (auto const &resource_pair : state->dynamic_resources) { - if (resource_pair.second > 0) { - resources_available = true; - } - } - return resources_available; -} - -void spillback_tasks_handler(LocalSchedulerState *state) { - SchedulingAlgorithmState *algorithm_state = state->algorithm_state; - - int64_t num_to_spillback = std::min( - static_cast(algorithm_state->dispatch_task_queue->size()), - RayConfig::instance().max_tasks_to_spillback()); - - auto it = algorithm_state->dispatch_task_queue->end(); - for (int64_t i = 0; i < num_to_spillback; i++) { - it--; - } - - for (int64_t i = 0; i < num_to_spillback; i++) { - it->IncrementSpillbackCount(); - // If an actor hasn't been created for a while, push a warning to the - // driver. - if (it->SpillbackCount() % - RayConfig::instance().actor_creation_num_spillbacks_warning() == - 0) { - TaskSpec *spec = it->Spec(); - if (TaskSpec_is_actor_creation_task(spec)) { - std::ostringstream error_message; - error_message << "The actor with ID " - << TaskSpec_actor_creation_id(spec) << " is taking a " - << "while to be created. It is possible that the " - << "cluster does not have enough resources to place this " - << "actor (this may be normal while an autoscaling " - << "is scaling up). Consider reducing the number of " - << "actors created, or " - << "increasing the number of slots available by using " - << "the --num-cpus, --num-gpus, and --resources flags. " - << "The actor creation task is requesting "; - for (auto const &resource_pair : - TaskSpec_get_required_resources(spec)) { - error_message << resource_pair.second << " " << resource_pair.first - << " "; - } - push_error(state->db, TaskSpec_driver_id(spec), - ErrorIndex::ACTOR_NOT_CREATED, error_message.str()); - } - } - - give_task_to_global_scheduler(state, algorithm_state, *it); - // Dequeue the task. - it = algorithm_state->dispatch_task_queue->erase(it); - } -} - -/** - * Assign as many tasks from the dispatch queue as possible. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @return Void. - */ -void dispatch_tasks(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state) { - /* Assign as many tasks as we can, while there are workers available. */ - for (auto it = algorithm_state->dispatch_task_queue->begin(); - it != algorithm_state->dispatch_task_queue->end();) { - TaskSpec *spec = it->Spec(); - /* If there is a task to assign, but there are no more available workers in - * the worker pool, then exit. Ensure that there will be an available - * worker during a future invocation of dispatch_tasks. */ - if (algorithm_state->available_workers.size() == 0) { - if (state->child_pids.size() == 0) { - /* If there are no workers, including those pending PID registration, - * then we must start a new one to replenish the worker pool. */ - start_worker(state); - } - return; - } - - /* Terminate early if there are no more resources available. */ - if (!resources_available(state)) { - return; - } - - /* Skip to the next task if this task cannot currently be satisfied. */ - if (!check_dynamic_resources(state, - TaskSpec_get_required_resources(spec))) { - /* This task could not be satisfied -- proceed to the next task. */ - ++it; - continue; - } - - /* Dispatch this task to an available worker and dequeue the task. */ - RAY_LOG(DEBUG) << "Dispatching task"; - /* Get the last available worker in the available worker queue. 
*/ - LocalSchedulerClient *worker = algorithm_state->available_workers.back(); - /* Tell the available worker to execute the task. */ - assign_task_to_worker(state, *it, worker); - /* Remove the worker from the available queue, and add it to the executing - * workers. */ - algorithm_state->available_workers.pop_back(); - algorithm_state->executing_workers.push_back(worker); - print_resource_info(state, spec); - /* Dequeue the task. */ - it = algorithm_state->dispatch_task_queue->erase(it); - } /* End for each task in the dispatch queue. */ -} - -/** - * Attempt to dispatch both regular tasks and actor tasks. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @return Void. - */ -void dispatch_all_tasks(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state) { - /* First attempt to dispatch regular tasks. */ - dispatch_tasks(state, algorithm_state); - - /* Attempt to dispatch actor tasks. */ - auto it = algorithm_state->actors_with_pending_tasks.begin(); - while (it != algorithm_state->actors_with_pending_tasks.end()) { - // We cannot short-circuit and exit here if there are no resources - // available because actor methods may require 0 CPUs. - - /* We increment the iterator ahead of time because the call to - * dispatch_actor_task may invalidate the current iterator. */ - ActorID actor_id = *it; - it++; - /* Dispatch tasks for the current actor. */ - dispatch_actor_task(state, algorithm_state, actor_id); - } -} - -/** - * A helper function to allocate a queue entry for a task specification and - * push it onto a generic queue. - * - * @param state The state of the local scheduler. - * @param task_queue A pointer to a task queue. NOTE: Because we are using - * utlist.h, we must pass in a pointer to the queue we want to append - * to. If we passed in the queue itself and the queue was empty, this - * would append the task to a queue that we don't have a reference to. - * @param task_entry A pointer to the task entry to queue. - * @param from_global_scheduler Whether or not the task was from a global - * scheduler. If false, the task was submitted by a worker. - * @return A reference to the entry in the queue that was pushed. - */ -std::list::iterator queue_task( - LocalSchedulerState *state, - std::list *task_queue, - TaskExecutionSpec &task_entry, - bool from_global_scheduler) { - /* The task has been added to a local scheduler queue. Write the entry in the - * task table to notify others that we have queued it. */ - if (state->db != NULL) { - Task *task = - Task_alloc(task_entry, TaskStatus::QUEUED, get_db_client_id(state->db)); - if (from_global_scheduler) { - /* If the task is from the global scheduler, it's already been added to - * the task table, so just update the entry. */ - task_table_update(state->db, task, NULL, NULL, NULL); - } else { - /* Otherwise, this is the first time the task has been seen in the system - * (unless it's a resubmission of a previous task), so add the entry. */ - task_table_add_task(state->db, task, NULL, NULL, NULL); - } - } - - /* Copy the spec and add it to the task queue. The allocated spec will be - * freed when it is assigned to a worker. */ - TaskExecutionSpec copy = TaskExecutionSpec(&task_entry); - task_queue->push_back(std::move(copy)); - /* Since we just queued the task, we can get a reference to it by going to - * the last element in the queue. */ - auto it = task_queue->end(); - --it; - - return it; -} - -/** - * Queue a task whose dependencies are missing. 
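A compact sketch of the dispatch loop just shown: walk the dispatch queue while workers remain, stop when no workers or resources are left, skip tasks whose resource demand cannot currently be met, and otherwise hand the task to a worker and dequeue it. The Task struct and the single CPU resource are simplifying assumptions for illustration.

#include <iostream>
#include <list>
#include <string>
#include <vector>

struct Task {
  std::string name;
  double cpus_required;
};

int main() {
  std::list<Task> dispatch_queue = {{"a", 1.0}, {"b", 4.0}, {"c", 1.0}};
  std::vector<std::string> available_workers = {"w1", "w2"};
  double cpus_available = 2.0;

  for (auto it = dispatch_queue.begin(); it != dispatch_queue.end();) {
    if (available_workers.empty() || cpus_available <= 0.0) {
      break;  // nothing left to assign tasks to
    }
    if (it->cpus_required > cpus_available) {
      ++it;  // cannot satisfy this task right now; try the next one
      continue;
    }
    // Assign the task to the last available worker and dequeue it.
    std::cout << it->name << " -> " << available_workers.back() << "\n";
    cpus_available -= it->cpus_required;
    available_workers.pop_back();
    it = dispatch_queue.erase(it);
  }
}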
When the task's object - * dependencies become available, the task will be moved to the dispatch queue. - * If we have a connection to a plasma manager, begin trying to fetch the - * dependencies. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param spec The task specification to queue. - * @param from_global_scheduler Whether or not the task was from a global - * scheduler. If false, the task was submitted by a worker. - * @return Void. - */ -void queue_waiting_task(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec, - bool from_global_scheduler) { - /* For actor tasks, do not queue tasks that have already been executed. */ - auto spec = execution_spec.Spec(); - if (!TaskSpec_actor_id(spec).is_nil()) { - auto entry = - algorithm_state->local_actor_infos.find(TaskSpec_actor_id(spec)); - if (entry != algorithm_state->local_actor_infos.end()) { - /* Find the highest task counter with the same handle ID as the task to - * queue. */ - auto &task_counters = entry->second.task_counters; - auto task_counter = task_counters.find(TaskSpec_actor_handle_id(spec)); - if (task_counter != task_counters.end() && - TaskSpec_actor_counter(spec) < task_counter->second) { - /* If the task to queue has a lower task counter, do not queue it. */ - RAY_LOG(INFO) << "A task that has already been executed has been " - << "resubmitted, so we are ignoring it. This should only " - << "happen during reconstruction."; - return; - } - } - } - - RAY_LOG(DEBUG) << "Queueing task in waiting queue"; - auto it = queue_task(state, algorithm_state->waiting_task_queue, - execution_spec, from_global_scheduler); - fetch_missing_dependencies(state, algorithm_state, it); -} - -/** - * Queue a task whose dependencies are ready. When the task reaches the front - * of the dispatch queue and workers are available, it will be assigned. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param spec The task specification to queue. - * @param from_global_scheduler Whether or not the task was from a global - * scheduler. If false, the task was submitted by a worker. - * @return Void. - */ -void queue_dispatch_task(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec, - bool from_global_scheduler) { - RAY_LOG(DEBUG) << "Queueing task in dispatch queue"; - TaskSpec *spec = execution_spec.Spec(); - if (TaskSpec_is_actor_task(spec)) { - queue_actor_task(state, algorithm_state, execution_spec, - from_global_scheduler); - } else { - queue_task(state, algorithm_state->dispatch_task_queue, execution_spec, - from_global_scheduler); - } -} - -/** - * Add the task to the proper local scheduler queue. This assumes that the - * scheduling decision to place the task on this node has already been made, - * whether locally or by the global scheduler. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param spec The task specification to queue. - * @param from_global_scheduler Whether or not the task was from a global - * scheduler. If false, the task was submitted by a worker. - * @return Void. 
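The routing decision documented here (waiting queue versus dispatch queue, depending on whether every dependency is already local) can be illustrated with a small standalone sketch. The Task type and CanRun helper below are stand-ins, not the removed code's API.

#include <initializer_list>
#include <iostream>
#include <list>
#include <string>
#include <unordered_set>
#include <vector>

struct Task {
  std::string name;
  std::vector<std::string> dependencies;
};

bool CanRun(const Task &task, const std::unordered_set<std::string> &local) {
  for (const auto &dep : task.dependencies) {
    if (local.count(dep) == 0) return false;  // a dependency is not local yet
  }
  return true;
}

int main() {
  std::unordered_set<std::string> local_objects = {"obj-1"};
  std::list<Task> waiting_queue, dispatch_queue;

  for (const Task &task : {Task{"t1", {"obj-1"}}, Task{"t2", {"obj-2"}}}) {
    if (CanRun(task, local_objects)) {
      dispatch_queue.push_back(task);  // ready to be assigned to a worker
    } else {
      waiting_queue.push_back(task);   // wait (and fetch) missing dependencies
    }
  }
  std::cout << dispatch_queue.size() << " ready, "
            << waiting_queue.size() << " waiting\n";  // 1 ready, 1 waiting
}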
- */ -void queue_task_locally(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec, - bool from_global_scheduler) { - if (can_run(algorithm_state, execution_spec)) { - /* Dependencies are ready, so push the task to the dispatch queue. */ - queue_dispatch_task(state, algorithm_state, execution_spec, - from_global_scheduler); - } else { - /* Dependencies are not ready, so push the task to the waiting queue. */ - queue_waiting_task(state, algorithm_state, execution_spec, - from_global_scheduler); - } -} - -void give_task_to_local_scheduler_retry(UniqueID id, - void *user_context, - void *user_data) { - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - Task *task = (Task *) user_data; - RAY_CHECK(Task_state(task) == TaskStatus::SCHEDULED); - - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - TaskSpec *spec = execution_spec->Spec(); - RAY_CHECK(TaskSpec_is_actor_task(spec)); - - ActorID actor_id = TaskSpec_actor_id(spec); - - if (state->actor_mapping.count(actor_id) == 0) { - // Process the actor task submission again. This will cache the task - // locally until a new actor creation notification is broadcast. We will - // attempt to reissue the actor creation tasks for all cached actor tasks - // in rerun_actor_creation_tasks_timeout_handler. - handle_actor_task_submitted(state, state->algorithm_state, *execution_spec); - return; - } - - DBClientID remote_local_scheduler_id = - state->actor_mapping[actor_id].local_scheduler_id; - - // TODO(rkn): db_client_table_cache_get is a blocking call, is this a - // performance issue? - DBClient remote_local_scheduler = - db_client_table_cache_get(state->db, remote_local_scheduler_id); - - // Check if the local scheduler that we're assigning this task to is still - // alive. - if (remote_local_scheduler.is_alive) { - // The local scheduler is still alive, which means that perhaps it hasn't - // subscribed to the appropriate channel yet, so retrying should suffice. - // This should be rare. - give_task_to_local_scheduler( - state, state->algorithm_state, *execution_spec, - state->actor_mapping[actor_id].local_scheduler_id); - } else { - // The local scheduler is dead, so we will need to recreate the actor by - // invoking reconstruction. - RAY_LOG(INFO) << "Local scheduler " << remote_local_scheduler_id - << " that was running actor " << actor_id << " died."; - RAY_CHECK(state->actor_mapping.count(actor_id) == 1); - // Update the actor mapping. - state->actor_mapping.erase(actor_id); - // Process the actor task submission again. This will cache the task - // locally until a new actor creation notification is broadcast. We will - // attempt to reissue the actor creation tasks for all cached actor tasks - // in rerun_actor_creation_tasks_timeout_handler. - handle_actor_task_submitted(state, state->algorithm_state, *execution_spec); - } -} - -/** - * Give a task directly to another local scheduler. This is currently only used - * for assigning actor tasks to the local scheduler responsible for that actor. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param spec The task specification to schedule. - * @param local_scheduler_id The ID of the local scheduler to give the task to. - * @return Void. 
- */ -void give_task_to_local_scheduler(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec, - DBClientID local_scheduler_id) { - if (local_scheduler_id == get_db_client_id(state->db)) { - RAY_LOG(WARNING) << "Local scheduler is trying to assign a task to itself."; - } - RAY_CHECK(state->db != NULL); - /* Assign the task to the relevant local scheduler. */ - RAY_CHECK(state->config.global_scheduler_exists); - Task *task = - Task_alloc(execution_spec, TaskStatus::SCHEDULED, local_scheduler_id); - auto retryInfo = RetryInfo{ - .num_retries = 0, // This value is unused. - .timeout = 0, // This value is unused. - .fail_callback = give_task_to_local_scheduler_retry, - }; - - task_table_add_task(state->db, task, &retryInfo, NULL, state); -} - -void give_task_to_global_scheduler_retry(UniqueID id, - void *user_context, - void *user_data) { - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - Task *task = (Task *) user_data; - RAY_CHECK(Task_state(task) == TaskStatus::WAITING); - - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - TaskSpec *spec = execution_spec->Spec(); - RAY_CHECK(!TaskSpec_is_actor_task(spec)); - - give_task_to_global_scheduler(state, state->algorithm_state, *execution_spec); -} - -/** - * Give a task to the global scheduler to schedule. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param spec The task specification to schedule. - * @return Void. - */ -void give_task_to_global_scheduler(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec) { - if (state->db == NULL || !state->config.global_scheduler_exists) { - /* A global scheduler is not available, so queue the task locally. */ - queue_task_locally(state, algorithm_state, execution_spec, false); - return; - } - /* Pass on the task to the global scheduler. */ - RAY_CHECK(state->config.global_scheduler_exists); - Task *task = Task_alloc(execution_spec, TaskStatus::WAITING, - get_db_client_id(state->db)); - RAY_CHECK(state->db != NULL); - auto retryInfo = RetryInfo{ - .num_retries = 0, // This value is unused. - .timeout = 0, // This value is unused. - .fail_callback = give_task_to_global_scheduler_retry, - }; - task_table_add_task(state->db, task, &retryInfo, NULL, state); -} - -bool resource_constraints_satisfied(LocalSchedulerState *state, - TaskSpec *spec) { - /* At the local scheduler, if required resource vector exceeds either static - * or dynamic resource vector, the resource constraint is not satisfied. */ - for (auto const &resource_pair : TaskSpec_get_required_resources(spec)) { - double required_resource = resource_pair.second; - if (required_resource > state->static_resources[resource_pair.first] || - required_resource > state->dynamic_resources[resource_pair.first]) { - return false; - } - } - - if (TaskSpec_is_actor_creation_task(spec) && - state->static_resources["CPU"] != 0) { - return false; - } - - return true; -} - -void handle_task_submitted(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec) { - TaskSpec *spec = execution_spec.Spec(); - /* TODO(atumanov): if static is satisfied and local objects ready, but dynamic - * resource is currently unavailable, then consider queueing task locally and - * recheck dynamic next time. 
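As a hedged illustration of the feasibility test in resource_constraints_satisfied above, the sketch below checks a task's resource demand against both the node's total (static) and currently free (dynamic) resources, using plain maps in place of the real resource vectors; an absent entry is treated as unavailable.

#include <iostream>
#include <string>
#include <unordered_map>

using ResourceMap = std::unordered_map<std::string, double>;

bool ConstraintsSatisfied(const ResourceMap &required,
                          const ResourceMap &static_resources,
                          const ResourceMap &dynamic_resources) {
  for (const auto &pair : required) {
    auto total = static_resources.find(pair.first);
    auto free_now = dynamic_resources.find(pair.first);
    if (total == static_resources.end() ||
        free_now == dynamic_resources.end() ||
        pair.second > total->second || pair.second > free_now->second) {
      return false;  // the node cannot (currently) provide this resource
    }
  }
  return true;
}

int main() {
  ResourceMap total = {{"CPU", 4}, {"GPU", 1}};
  ResourceMap free_now = {{"CPU", 2}, {"GPU", 0}};
  std::cout << ConstraintsSatisfied({{"CPU", 2}}, total, free_now) << "\n";  // 1
  std::cout << ConstraintsSatisfied({{"GPU", 1}}, total, free_now) << "\n";  // 0
}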
*/ - - // If this task's constraints are satisfied, dependencies are available - // locally, and there is an available worker, then enqueue the task in the - // dispatch queue and trigger task dispatch. Otherwise, pass the task along to - // the global scheduler if there is one. - // Note that actor creation tasks automatically go to the global scheduler. - // See https://github.com/ray-project/ray/issues/1756 for more discussion. - // This is a hack to improve actor load balancing (and to prevent the scenario - // where all actors are started locally). - if (resource_constraints_satisfied(state, spec) && - (algorithm_state->available_workers.size() > 0) && - can_run(algorithm_state, execution_spec) && - !TaskSpec_is_actor_creation_task(spec)) { - queue_dispatch_task(state, algorithm_state, execution_spec, false); - } else { - /* Give the task to the global scheduler to schedule, if it exists. */ - give_task_to_global_scheduler(state, algorithm_state, execution_spec); - } - - /* Try to dispatch tasks, since we may have added one to the queue. */ - dispatch_tasks(state, algorithm_state); -} - -void handle_actor_task_submitted(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec) { - TaskSpec *task_spec = execution_spec.Spec(); - RAY_CHECK(TaskSpec_is_actor_task(task_spec)); - ActorID actor_id = TaskSpec_actor_id(task_spec); - - if (state->actor_mapping.count(actor_id) == 0) { - // Create a copy of the task to write to the task table. - Task *task = Task_alloc( - task_spec, execution_spec.SpecSize(), TaskStatus::ACTOR_CACHED, - get_db_client_id(state->db), execution_spec.ExecutionDependencies()); - - /* Add this task to a queue of tasks that have been submitted but the local - * scheduler doesn't know which actor is responsible for them. These tasks - * will be resubmitted (internally by the local scheduler) whenever a new - * actor notification arrives. NOTE(swang): These tasks have not yet been - * added to the task table. */ - TaskExecutionSpec task_entry = TaskExecutionSpec(&execution_spec); - algorithm_state->cached_submitted_actor_tasks.push_back( - std::move(task_entry)); - - // Even if the task can't be assigned to a worker yet, we should still write - // it to the task table. TODO(rkn): There's no need to do this more than - // once, and we could run into problems if we have very large numbers of - // tasks in this cache. - task_table_add_task(state->db, task, NULL, NULL, NULL); - - return; - } - - if (state->actor_mapping[actor_id].local_scheduler_id == - get_db_client_id(state->db)) { - /* This local scheduler is responsible for the actor, so handle the task - * locally. */ - queue_task_locally(state, algorithm_state, execution_spec, false); - /* Attempt to dispatch tasks to this actor. */ - dispatch_actor_task(state, algorithm_state, actor_id); - } else { - /* This local scheduler is not responsible for the task, so find the local - * scheduler that is responsible for this actor and assign the task directly - * to that local scheduler. 
*/ - give_task_to_local_scheduler( - state, algorithm_state, execution_spec, - state->actor_mapping[actor_id].local_scheduler_id); - } -} - -void handle_actor_creation_notification( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id) { - int num_cached_actor_tasks = - algorithm_state->cached_submitted_actor_tasks.size(); - - for (int i = 0; i < num_cached_actor_tasks; ++i) { - TaskExecutionSpec &task = algorithm_state->cached_submitted_actor_tasks[i]; - /* Note that handle_actor_task_submitted may append the spec to the end of - * the cached_submitted_actor_tasks array. */ - handle_actor_task_submitted(state, algorithm_state, task); - } - /* Remove all the tasks that were resubmitted. This does not erase the tasks - * that were newly appended to the cached_submitted_actor_tasks array. */ - auto begin = algorithm_state->cached_submitted_actor_tasks.begin(); - algorithm_state->cached_submitted_actor_tasks.erase( - begin, begin + num_cached_actor_tasks); -} - -void handle_task_scheduled(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec) { - /* This callback handles tasks that were assigned to this local scheduler by - * the global scheduler, so we can safely assert that there is a connection to - * the database. */ - RAY_CHECK(state->db != NULL); - RAY_CHECK(state->config.global_scheduler_exists); - - // Currently, the global scheduler will never assign a task to a local - // scheduler that has 0 CPUs. - RAY_CHECK(state->static_resources["CPU"] != 0); - - // Push the task to the appropriate queue. - queue_task_locally(state, algorithm_state, execution_spec, true); - dispatch_tasks(state, algorithm_state); -} - -void handle_actor_task_scheduled(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec) { - TaskSpec *spec = execution_spec.Spec(); - /* This callback handles tasks that were assigned to this local scheduler by - * the global scheduler or by other workers, so we can safely assert that - * there is a connection to the database. */ - RAY_CHECK(state->db != NULL); - RAY_CHECK(state->config.global_scheduler_exists); - /* Check that the task is meant to run on an actor that this local scheduler - * is responsible for. */ - RAY_CHECK(TaskSpec_is_actor_task(spec)); - ActorID actor_id = TaskSpec_actor_id(spec); - if (state->actor_mapping.count(actor_id) == 1) { - RAY_CHECK(state->actor_mapping[actor_id].local_scheduler_id == - get_db_client_id(state->db)); - } else { - /* This means that an actor has been assigned to this local scheduler, and a - * task for that actor has been received by this local scheduler, but this - * local scheduler has not yet processed the notification about the actor - * creation. This may be possible though should be very uncommon. If it does - * happen, it's ok. */ - RAY_LOG(INFO) << "handle_actor_task_scheduled called on local scheduler " - << "but the corresponding actor_map_entry is not present. " - << "This should be rare."; - } - /* Push the task to the appropriate queue. */ - queue_task_locally(state, algorithm_state, execution_spec, true); - dispatch_actor_task(state, algorithm_state, actor_id); -} - -void handle_worker_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - RAY_CHECK(worker->task_in_progress == NULL); - /* Check that the worker isn't in the pool of available workers. 
*/ - RAY_CHECK(!worker_in_vector(algorithm_state->available_workers, worker)); - - /* Check that the worker isn't in the list of blocked workers. */ - RAY_CHECK(!worker_in_vector(algorithm_state->blocked_workers, worker)); - - /* If the worker was executing a task, it must have finished, so remove it - * from the list of executing workers. If the worker is connecting for the - * first time, it will not be in the list of executing workers. */ - remove_worker_from_vector(algorithm_state->executing_workers, worker); - /* Double check that we successfully removed the worker. */ - RAY_CHECK(!worker_in_vector(algorithm_state->executing_workers, worker)); - - /* Add worker to the list of available workers. */ - algorithm_state->available_workers.push_back(worker); - - /* Try to dispatch tasks. */ - dispatch_all_tasks(state, algorithm_state); -} - -void handle_worker_removed(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - /* Make sure this is not an actor. */ - RAY_CHECK(worker->actor_id.is_nil()); - - /* Make sure that we remove the worker at most once. */ - int num_times_removed = 0; - - /* Remove the worker from available workers, if it's there. */ - bool removed_from_available = - remove_worker_from_vector(algorithm_state->available_workers, worker); - num_times_removed += removed_from_available; - /* Double check that we actually removed the worker. */ - RAY_CHECK(!worker_in_vector(algorithm_state->available_workers, worker)); - - /* Remove the worker from executing workers, if it's there. */ - bool removed_from_executing = - remove_worker_from_vector(algorithm_state->executing_workers, worker); - num_times_removed += removed_from_executing; - /* Double check that we actually removed the worker. */ - RAY_CHECK(!worker_in_vector(algorithm_state->executing_workers, worker)); - - /* Remove the worker from blocked workers, if it's there. */ - bool removed_from_blocked = - remove_worker_from_vector(algorithm_state->blocked_workers, worker); - num_times_removed += removed_from_blocked; - /* Double check that we actually removed the worker. */ - RAY_CHECK(!worker_in_vector(algorithm_state->blocked_workers, worker)); - - /* Make sure we removed the worker at most once. */ - RAY_CHECK(num_times_removed <= 1); - - /* Attempt to dispatch some tasks because some resources may have freed up. */ - dispatch_all_tasks(state, algorithm_state); -} - -void handle_actor_worker_disconnect(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker, - bool cleanup) { - /* Fail all in progress or queued tasks of the actor. */ - if (!cleanup) { - if (state->db != NULL) { - actor_table_mark_removed(state->db, worker->actor_id); - } - - if (worker->task_in_progress != NULL) { - finish_killed_task(state, - *Task_task_execution_spec(worker->task_in_progress)); - } - - state->removed_actors.insert(worker->actor_id); - - RAY_CHECK(algorithm_state->local_actor_infos.count(worker->actor_id) != 0); - LocalActorInfo &entry = - algorithm_state->local_actor_infos.find(worker->actor_id)->second; - for (auto &task : *entry.task_queue) { - finish_killed_task(state, task); - } - } - - remove_actor(algorithm_state, worker->actor_id); - - /* Attempt to dispatch some tasks because some resources may have freed up. */ - dispatch_all_tasks(state, algorithm_state); - - /* Start a worker to replace the removed actor's worker and replenish the - * worker pool. 
*/ - start_worker(state); -} - -/* NOTE(swang): For tasks that saved a checkpoint, we should consider - * overwriting the result table entries for the current task frontier to - * avoid duplicate task submissions during reconstruction. */ -void handle_actor_worker_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - ActorID actor_id = worker->actor_id; - RAY_CHECK(!actor_id.is_nil()); - /* Get the actor info for this worker. */ - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) == 1); - LocalActorInfo &entry = - algorithm_state->local_actor_infos.find(actor_id)->second; - RAY_CHECK(worker == entry.worker); - RAY_CHECK(!entry.worker_available); - /* If an actor task was assigned, mark returned dummy object as locally - * available. This is not added to the object table, so the update will be - * invisible to other nodes. */ - /* NOTE(swang): These objects are never cleaned up. We should consider - * removing the objects, e.g., when an actor is terminated. */ - if (!entry.execution_dependency.is_nil()) { - handle_object_available(state, algorithm_state, entry.execution_dependency); - } - /* Unset the fields indicating an assigned task. */ - entry.worker_available = true; - /* Assign new tasks if possible. */ - dispatch_all_tasks(state, algorithm_state); -} - -void handle_worker_blocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - /* Find the worker in the list of executing workers. */ - RAY_CHECK( - remove_worker_from_vector(algorithm_state->executing_workers, worker)); - - /* Check that the worker isn't in the list of blocked workers. */ - RAY_CHECK(!worker_in_vector(algorithm_state->blocked_workers, worker)); - - /* Add the worker to the list of blocked workers. */ - algorithm_state->blocked_workers.push_back(worker); - - /* Try to dispatch tasks, since we may have freed up some resources. */ - dispatch_all_tasks(state, algorithm_state); -} - -void handle_actor_worker_blocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - /* The actor case doesn't use equivalents of the blocked_workers and - * executing_workers lists. Are these necessary? */ - /* Try to dispatch tasks, since we may have freed up some resources. */ - dispatch_all_tasks(state, algorithm_state); -} - -void handle_worker_unblocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - /* Find the worker in the list of blocked workers. */ - RAY_CHECK( - remove_worker_from_vector(algorithm_state->blocked_workers, worker)); - - /* Check that the worker isn't in the list of executing workers. */ - RAY_CHECK(!worker_in_vector(algorithm_state->executing_workers, worker)); - - /* Add the worker to the list of executing workers. */ - algorithm_state->executing_workers.push_back(worker); -} - -void handle_actor_worker_unblocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) {} - -void handle_object_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ObjectID object_id) { - auto object_entry_it = algorithm_state->remote_objects.find(object_id); - - ObjectEntry entry; - /* Get the entry for this object from the active fetch request, or allocate - * one if needed. */ - if (object_entry_it != algorithm_state->remote_objects.end()) { - /* Remove the object from the active fetch requests. 
*/ - entry = object_entry_it->second; - algorithm_state->remote_objects.erase(object_id); - } - - /* Add the entry to the set of locally available objects. */ - RAY_CHECK(algorithm_state->local_objects.count(object_id) == 0); - algorithm_state->local_objects[object_id] = entry; - - if (!entry.dependent_tasks.empty()) { - /* Out of the tasks that were dependent on this object, if they are now - * ready to run, move them to the dispatch queue. */ - for (auto &it : entry.dependent_tasks) { - if (can_run(algorithm_state, *it)) { - if (TaskSpec_is_actor_task(it->Spec())) { - insert_actor_task_queue(state, algorithm_state, std::move(*it)); - } else { - algorithm_state->dispatch_task_queue->push_back(std::move(*it)); - } - /* Remove the entry with a matching TaskSpec pointer from the waiting - * queue, but do not free the task spec. */ - algorithm_state->waiting_task_queue->erase(it); - } - } - /* Try to dispatch tasks, since we may have added some from the waiting - * queue. */ - dispatch_all_tasks(state, algorithm_state); - /* Clean up the records for dependent tasks. */ - entry.dependent_tasks.clear(); - } -} - -void handle_object_removed(LocalSchedulerState *state, - ObjectID removed_object_id) { - /* Remove the object from the set of locally available objects. */ - SchedulingAlgorithmState *algorithm_state = state->algorithm_state; - - RAY_CHECK(algorithm_state->local_objects.count(removed_object_id) == 1); - algorithm_state->local_objects.erase(removed_object_id); - - /* Track queued tasks that were dependent on this object. - * NOTE: Since objects often get removed in batches (e.g., during eviction), - * we may end up iterating through the queues many times in a row. If this - * turns out to be a bottleneck, consider tracking dependencies even for - * tasks in the dispatch queue, or batching object notifications. */ - /* Track the dependency for tasks that were in the dispatch queue. Remove - * these tasks from the dispatch queue and push them to the waiting queue. */ - for (auto it = algorithm_state->dispatch_task_queue->begin(); - it != algorithm_state->dispatch_task_queue->end();) { - if (it->DependsOn(removed_object_id)) { - /* This task was dependent on the removed object. */ - RAY_LOG(DEBUG) << "Moved task from dispatch queue back to waiting queue"; - algorithm_state->waiting_task_queue->push_back(std::move(*it)); - /* Remove the task from the dispatch queue, but do not free the task - * spec. */ - it = algorithm_state->dispatch_task_queue->erase(it); - } else { - /* The task can still run, so continue to the next task. */ - ++it; - } - } - - std::vector empty_actor_queues; - for (auto it = algorithm_state->actors_with_pending_tasks.begin(); - it != algorithm_state->actors_with_pending_tasks.end(); it++) { - auto actor_info = algorithm_state->local_actor_infos[*it]; - for (auto queue_it = actor_info.task_queue->begin(); - queue_it != actor_info.task_queue->end();) { - if (queue_it->DependsOn(removed_object_id)) { - /* This task was dependent on the removed object. */ - RAY_LOG(DEBUG) << "Moved task from actor dispatch queue back to " - << "waiting queue"; - algorithm_state->waiting_task_queue->push_back(std::move(*queue_it)); - /* Remove the task from the dispatch queue, but do not free the task - * spec. 
*/ - queue_it = actor_info.task_queue->erase(queue_it); - if (actor_info.task_queue->size() == 0) { - empty_actor_queues.push_back(*it); - } - } else { - ++queue_it; - } - } - } - for (auto actor_id : empty_actor_queues) { - algorithm_state->actors_with_pending_tasks.erase(actor_id); - } - - /* Track the dependency for tasks that are in the waiting queue, including - * those that were just moved from the dispatch queue. */ - for (auto it = algorithm_state->waiting_task_queue->begin(); - it != algorithm_state->waiting_task_queue->end(); ++it) { - int64_t num_dependencies = it->NumDependencies(); - for (int64_t i = 0; i < num_dependencies; ++i) { - int count = it->DependencyIdCount(i); - for (int j = 0; j < count; ++j) { - ObjectID dependency_id = it->DependencyId(i, j); - if (dependency_id == removed_object_id) { - /* Do not request a transfer from other plasma managers if this is an - * execution dependency. */ - bool request_transfer = it->IsStaticDependency(i); - fetch_missing_dependency(state, algorithm_state, it, - removed_object_id.to_plasma_id(), - request_transfer); - } - } - } - } -} - -void handle_driver_removed(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - WorkerID driver_id) { - /* Loop over fetch requests. This must be done before we clean up the waiting - * task queue and the dispatch task queue because this map contains iterators - * for those lists, which will be invalidated when we clean up those lists.*/ - for (auto it = algorithm_state->remote_objects.begin(); - it != algorithm_state->remote_objects.end();) { - /* Loop over the tasks that are waiting for this object and remove the tasks - * for the removed driver. */ - auto task_it_it = it->second.dependent_tasks.begin(); - while (task_it_it != it->second.dependent_tasks.end()) { - /* If the dependent task was a task for the removed driver, remove it from - * this vector. */ - TaskSpec *spec = (*task_it_it)->Spec(); - if (TaskSpec_driver_id(spec) == driver_id) { - task_it_it = it->second.dependent_tasks.erase(task_it_it); - } else { - task_it_it++; - } - } - /* If there are no more dependent tasks for this object, then remove the - * ObjectEntry. */ - if (it->second.dependent_tasks.size() == 0) { - it = algorithm_state->remote_objects.erase(it); - } else { - it++; - } - } - - /* Remove this driver's tasks from the waiting task queue. */ - auto it = algorithm_state->waiting_task_queue->begin(); - while (it != algorithm_state->waiting_task_queue->end()) { - TaskSpec *spec = it->Spec(); - if (TaskSpec_driver_id(spec) == driver_id) { - it = algorithm_state->waiting_task_queue->erase(it); - } else { - it++; - } - } - - /* Remove this driver's tasks from the dispatch task queue. */ - it = algorithm_state->dispatch_task_queue->begin(); - while (it != algorithm_state->dispatch_task_queue->end()) { - TaskSpec *spec = it->Spec(); - if (TaskSpec_driver_id(spec) == driver_id) { - it = algorithm_state->dispatch_task_queue->erase(it); - } else { - it++; - } - } - - // Remove this driver's tasks from the cached actor tasks. Note that this loop - // could be very slow if the vector of cached actor tasks is very long. - for (auto it = algorithm_state->cached_submitted_actor_tasks.begin(); - it != algorithm_state->cached_submitted_actor_tasks.end();) { - TaskSpec *spec = (*it).Spec(); - if (TaskSpec_driver_id(spec) == driver_id) { - it = algorithm_state->cached_submitted_actor_tasks.erase(it); - } else { - ++it; - } - } - - /* TODO(rkn): Should we clean up the actor data structures? 
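The two handlers above move tasks between the waiting and dispatch queues as objects appear in and disappear from the local store. The following simplified sketch shows both transitions; the Task and Queues types and the string IDs are stand-ins rather than the removed data structures.

#include <iostream>
#include <list>
#include <string>
#include <unordered_set>
#include <vector>

struct Task {
  std::string name;
  std::vector<std::string> dependencies;
  bool DependsOn(const std::string &obj) const {
    for (const auto &dep : dependencies) {
      if (dep == obj) return true;
    }
    return false;
  }
};

struct Queues {
  std::unordered_set<std::string> local_objects;
  std::list<Task> waiting, dispatch;

  // An object arrived: waiting tasks whose dependencies are now all local
  // move to the dispatch queue.
  void OnObjectAvailable(const std::string &obj) {
    local_objects.insert(obj);
    for (auto it = waiting.begin(); it != waiting.end();) {
      bool ready = true;
      for (const auto &dep : it->dependencies) {
        if (local_objects.count(dep) == 0) ready = false;
      }
      if (ready) {
        dispatch.push_back(std::move(*it));
        it = waiting.erase(it);  // task is now runnable
      } else {
        ++it;
      }
    }
  }

  // An object was evicted: dispatch-queue tasks that depend on it move back
  // to the waiting queue (where a re-fetch would be issued).
  void OnObjectRemoved(const std::string &obj) {
    local_objects.erase(obj);
    for (auto it = dispatch.begin(); it != dispatch.end();) {
      if (it->DependsOn(obj)) {
        waiting.push_back(std::move(*it));
        it = dispatch.erase(it);
      } else {
        ++it;
      }
    }
  }
};

int main() {
  Queues q;
  q.waiting.push_back({"t1", {"obj-1"}});
  q.OnObjectAvailable("obj-1");
  std::cout << q.dispatch.size() << " dispatchable\n";  // 1 dispatchable
  q.OnObjectRemoved("obj-1");
  std::cout << q.waiting.size() << " waiting again\n";  // 1 waiting again
}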
*/ -} - -int num_waiting_tasks(SchedulingAlgorithmState *algorithm_state) { - return algorithm_state->waiting_task_queue->size(); -} - -int num_dispatch_tasks(SchedulingAlgorithmState *algorithm_state) { - return algorithm_state->dispatch_task_queue->size(); -} - -void print_worker_info(const char *message, - SchedulingAlgorithmState *algorithm_state) { - RAY_LOG(DEBUG) << message << ": " << algorithm_state->available_workers.size() - << " available, " << algorithm_state->executing_workers.size() - << " executing, " << algorithm_state->blocked_workers.size() - << " blocked"; -} - -std::unordered_map get_actor_task_counters( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id) { - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0); - return algorithm_state->local_actor_infos[actor_id].task_counters; -} - -void set_actor_task_counters( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id, - const std::unordered_map &task_counters) { - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0); - /* Overwrite the current task counters for the actor. This is necessary - * during reconstruction when resuming from a checkpoint so that we can - * resume the task frontier at the time that the checkpoint was saved. */ - auto &entry = algorithm_state->local_actor_infos[actor_id]; - entry.task_counters = task_counters; - - /* Filter out tasks for the actor that were submitted earlier than the new - * task counter. These represent tasks that executed before the actor's - * resumed checkpoint, and therefore should not be re-executed. */ - for (auto it = entry.task_queue->begin(); it != entry.task_queue->end();) { - /* Filter out duplicate tasks for the actor that are runnable. */ - TaskSpec *pending_task_spec = it->Spec(); - ActorHandleID handle_id = TaskSpec_actor_handle_id(pending_task_spec); - auto task_counter = entry.task_counters.find(handle_id); - if (task_counter != entry.task_counters.end() && - TaskSpec_actor_counter(pending_task_spec) < task_counter->second) { - /* If the task's counter is less than the highest count for that handle, - * then remove it from the actor's runnable queue. */ - it = entry.task_queue->erase(it); - } else { - it++; - } - } - for (auto it = algorithm_state->waiting_task_queue->begin(); - it != algorithm_state->waiting_task_queue->end();) { - /* Filter out duplicate tasks for the actor that are waiting on a missing - * dependency. */ - TaskSpec *spec = it->Spec(); - if (TaskSpec_actor_id(spec) == actor_id && - TaskSpec_actor_counter(spec) < - entry.task_counters[TaskSpec_actor_handle_id(spec)]) { - /* If the waiting task is for the same actor and its task counter is less - * than the highest count for that handle, then clear its object - * dependencies and remove it from the queue. 
*/ - clear_missing_dependencies(algorithm_state, it); - it = algorithm_state->waiting_task_queue->erase(it); - } else { - it++; - } - } -} - -std::unordered_map get_actor_frontier( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id) { - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0); - return algorithm_state->local_actor_infos[actor_id].frontier_dependencies; -} - -void set_actor_frontier( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id, - const std::unordered_map &frontier_dependencies) { - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0); - auto entry = algorithm_state->local_actor_infos[actor_id]; - entry.frontier_dependencies = frontier_dependencies; - for (auto frontier_dependency : entry.frontier_dependencies) { - if (algorithm_state->local_objects.count(frontier_dependency.second) == 0) { - handle_object_available(state, algorithm_state, - frontier_dependency.second); - } - } -} diff --git a/src/local_scheduler/local_scheduler_algorithm.h b/src/local_scheduler/local_scheduler_algorithm.h deleted file mode 100644 index 9238d5db58e55..0000000000000 --- a/src/local_scheduler/local_scheduler_algorithm.h +++ /dev/null @@ -1,438 +0,0 @@ -#ifndef LOCAL_SCHEDULER_ALGORITHM_H -#define LOCAL_SCHEDULER_ALGORITHM_H - -#include "local_scheduler_shared.h" -#include "common/task.h" -#include "state/local_scheduler_table.h" - -/* ==== The scheduling algorithm ==== - * - * This file contains declaration for all functions and data structures - * that need to be provided if you want to implement a new algorithms - * for the local scheduler. - * - */ - -/** - * Initialize the scheduler state. - * - * @return State managed by the scheduling algorithm. - */ -SchedulingAlgorithmState *SchedulingAlgorithmState_init(void); - -/** - * Free the scheduler state. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @return Void. - */ -void SchedulingAlgorithmState_free(SchedulingAlgorithmState *algorithm_state); - -/** - * - */ -void provide_scheduler_info(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerInfo *info); - -/** - * This function will be called when a new task is submitted by a worker for - * execution. The task will either be: - * 1. Put into the waiting queue, where it will wait for its dependencies to - * become available. - * 2. Put into the dispatch queue, where it will wait for an available worker. - * 3. Given to the global scheduler to be scheduled. - * - * Currently, the local scheduler policy is to keep the task if its - * dependencies are ready and there is an available worker. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param task Task that is submitted by the worker. - * @return Void. - */ -void handle_task_submitted(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec); - -/** - * This version of handle_task_submitted is used when the task being submitted - * is a method of an actor. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param task Task that is submitted by the worker. - * @return Void. 
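A small sketch of the per-handle counter filtering performed by set_actor_task_counters above: after restoring counters from a checkpoint, any queued actor task whose counter is below the restored high-water mark for its handle has already executed and is dropped. HandleId and QueuedTask are illustrative stand-ins.

#include <cstdint>
#include <iostream>
#include <list>
#include <string>
#include <unordered_map>

using HandleId = std::string;

struct QueuedTask {
  HandleId handle;
  int64_t counter;  // position of this task in its handle's submission order
};

void DropAlreadyExecuted(
    std::list<QueuedTask> &queue,
    const std::unordered_map<HandleId, int64_t> &counters) {
  for (auto it = queue.begin(); it != queue.end();) {
    auto restored = counters.find(it->handle);
    if (restored != counters.end() && it->counter < restored->second) {
      it = queue.erase(it);  // executed before the checkpoint; do not rerun
    } else {
      ++it;
    }
  }
}

int main() {
  std::list<QueuedTask> queue = {{"h1", 3}, {"h1", 5}, {"h2", 0}};
  DropAlreadyExecuted(queue, {{"h1", 5}});  // counters below 5 for h1 are stale
  std::cout << queue.size() << " tasks remain\n";  // 2 tasks remain
}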
- */ -void handle_actor_task_submitted(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec); - -/** - * This function will be called when the local scheduler receives a notification - * about the creation of a new actor. This can be used by the scheduling - * algorithm to resubmit cached actor tasks. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor being created. - * @return Void. - */ -void handle_actor_creation_notification( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id); - -/** - * This function will be called when a task is assigned by the global scheduler - * for execution on this local scheduler. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param task Task that is assigned by the global scheduler. - * @return Void. - */ -void handle_task_scheduled(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec); - -/** - * This function will be called when an actor task is assigned by the global - * scheduler or by another local scheduler for execution on this local - * scheduler. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param task Task that is assigned by the global scheduler. - * @return Void. - */ -void handle_actor_task_scheduled(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec); - -/** - * This function is called if a new object becomes available in the local - * plasma store. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param object_id ID of the object that became available. - * @return Void. - */ -void handle_object_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ObjectID object_id); - -/** - * This function is called if an object is removed from the local plasma store. - * - * @param state The state of the local scheduler. - * @param object_id ID of the object that was removed. - * @return Void. - */ -void handle_object_removed(LocalSchedulerState *state, ObjectID object_id); - -/** - * This function is called when a new worker becomes available. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is available. - * @return Void. - */ -void handle_worker_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * This function is called when a worker is removed. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is removed. - * @return Void. - */ -void handle_worker_removed(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * This version of handle_worker_available is called whenever the worker that is - * available is running an actor. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is available. 
- * @return Void. - */ -void handle_actor_worker_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * Handle the fact that a new worker is available for running an actor. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor running on the worker. - * @param initial_execution_dependency The dummy object ID of the actor - * creation task. - * @param worker The worker that was converted to an actor. - * @return Void. - */ -void handle_convert_worker_to_actor( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - const ActorID &actor_id, - const ObjectID &initial_execution_dependency, - LocalSchedulerClient *worker); - -/** - * Handle the fact that a worker running an actor has disconnected. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that was disconnected. - * @param cleanup Whether the disconnect was during cleanup. - * @return Void. - */ -void handle_actor_worker_disconnect(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker, - bool cleanup); - -/** - * This function is called when a worker that was executing a task becomes - * blocked on an object that isn't available locally yet. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is blocked. - * @return Void. - */ -void handle_worker_blocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * This function is called when an actor that was executing a task becomes - * blocked on an object that isn't available locally yet. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is blocked. - * @return Void. - */ -void handle_actor_worker_blocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * This function is called when a worker that was blocked on an object that - * wasn't available locally yet becomes unblocked. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is now unblocked. - * @return Void. - */ -void handle_worker_unblocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * This function is called when an actor that was blocked on an object that - * wasn't available locally yet becomes unblocked. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is now unblocked. - * @return Void. - */ -void handle_actor_worker_unblocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * Process the fact that a driver has been removed. This will remove all of the - * tasks for that driver from the scheduling algorithm's internal data - * structures. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. 
- * @param driver_id The ID of the driver that was removed. - * @return Void. - */ -void handle_driver_removed(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - WorkerID driver_id); - -/** - * This function fetches queued task's missing object dependencies. It is - * called every local_scheduler_fetch_timeout_milliseconds. - * - * @param loop The local scheduler's event loop. - * @param id The ID of the timer that triggers this function. - * @param context The function's context. - * @return An integer representing the time interval in seconds before the - * next invocation of the function. - */ -int fetch_object_timeout_handler(event_loop *loop, timer_id id, void *context); - -/** - * This function initiates reconstruction for task's missing object - * dependencies. It is called every - * local_scheduler_reconstruction_timeout_milliseconds, but it may not initiate - * reconstruction for every missing object. - * - * @param loop The local scheduler's event loop. - * @param id The ID of the timer that triggers this function. - * @param context The function's context. - * @return An integer representing the time interval in seconds before the - * next invocation of the function. - */ -int reconstruct_object_timeout_handler(event_loop *loop, - timer_id id, - void *context); - -/// This function initiates reconstruction for the actor creation tasks -/// corresponding to the actor tasks cached in the local scheduler. -/// -/// \param loop The local scheduler's event loop. -/// \param id The ID of the timer that triggers this function. -/// \param context The function's context. -/// \return An integer representing the time interval in seconds before the -/// next invocation of the function. -int rerun_actor_creation_tasks_timeout_handler(event_loop *loop, - timer_id id, - void *context); - -/** - * Check whether an object, including actor dummy objects, is locally - * available. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @param object_id The ID of the object to check for. - * @return A bool representing whether the object is locally available. - */ -bool object_locally_available(SchedulingAlgorithmState *algorithm_state, - ObjectID object_id); - -/// Spill some tasks back to the global scheduler. This function implements the -/// spillback policy. -/// -/// @param state The scheduler state. -/// @return Void. -void spillback_tasks_handler(LocalSchedulerState *state); - -/** - * A helper function to print debug information about the current state and - * number of workers. - * - * @param message A message to identify the log message. - * @param algorithm_state State maintained by the scheduling algorithm. - * @return Void. - */ -void print_worker_info(const char *message, - SchedulingAlgorithmState *algorithm_state); - -/* - * The actor frontier consists of the number of tasks executed so far and the - * execution dependencies required by the current runnable tasks, according to - * the actor's local scheduler. Since an actor may have multiple handles, the - * tasks submitted to the actor form a DAG, where nodes are tasks and edges are - * execution dependencies. The frontier is a cut across this DAG. The number of - * tasks so far is the number of nodes included in the DAG root's partition. - * - * The actor gets the current frontier of tasks from the local scheduler during - * a checkpoint save, so that it can save the point in the actor's lifetime at - * which the checkpoint was taken. 
If the actor later resumes from that - * checkpoint, the actor can set the current frontier of tasks in the local - * scheduler so that the same frontier of tasks can be made runnable again - * during reconstruction, and so that we do not duplicate execution of tasks - * that already executed before the checkpoint. - */ - -/** - * Get the number of tasks, per actor handle, that have been executed on an - * actor so far. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor whose task counters are returned. - * @return A map from handle ID to the number of tasks submitted by that handle - * that have executed so far. - */ -std::unordered_map get_actor_task_counters( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id); - -/** - * Set the number of tasks, per actor handle, that have been executed on an - * actor so far. All previous counts will be overwritten. Tasks that are - * waiting or runnable on the local scheduler that have a lower task count will - * be discarded, so that we don't duplicate execution. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor whose task counters are returned. - * @param task_counters A map from handle ID to the number of tasks submitted - * by that handle that have executed so far. - * @return Void. - */ -void set_actor_task_counters( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id, - const std::unordered_map &task_counters); - -/** - * Get the actor's frontier of task dependencies. - * NOTE(swang): The returned frontier only includes handles known by the local - * scheduler. It does not include handles for which the local scheduler has not - * seen a runnable task yet. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor whose task counters are returned. - * @return A map from handle ID to execution dependency for the earliest - * runnable task submitted through that handle. - */ -std::unordered_map get_actor_frontier( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id); - -/** - * Set the actor's frontier of task dependencies. The previous frontier will be - * overwritten. Any tasks that have an execution dependency on the new frontier - * (and that have all other dependencies fulfilled) will become runnable. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor whose task counters are returned. - * @param frontier_dependencies A map from handle ID to execution dependency - * for the earliest runnable task submitted through that handle. - * @return Void. - */ -void set_actor_frontier( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id, - const std::unordered_map &frontier_dependencies); - -/** The following methods are for testing purposes only. */ -#ifdef LOCAL_SCHEDULER_TEST -/** - * Get the number of tasks currently waiting for object dependencies to become - * available locally. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @return The number of tasks queued. - */ -int num_waiting_tasks(SchedulingAlgorithmState *algorithm_state); - -/** - * Get the number of tasks currently waiting for a worker to become available. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @return The number of tasks queued. 
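To make the checkpoint flow described above concrete, a save/restore sketch over the frontier API follows. The ActorHandleID key type, the int64_t/ObjectID value types, and the ActorCheckpoint container are assumptions inferred from the doc comments, not declarations copied from this header.

// Sketch of a checkpoint save/restore path built on the frontier API above.
#include <cstdint>
#include <unordered_map>
#include "local_scheduler_algorithm.h"
#include "local_scheduler_shared.h"

struct ActorCheckpoint {
  // Assumed types: handle ID -> executed-task count, handle ID -> dependency.
  std::unordered_map<ActorHandleID, int64_t> task_counters;
  std::unordered_map<ActorHandleID, ObjectID> frontier;
};

ActorCheckpoint save_actor_checkpoint(SchedulingAlgorithmState *algorithm_state,
                                      ActorID actor_id) {
  ActorCheckpoint checkpoint;
  // How many tasks each handle has executed so far...
  checkpoint.task_counters = get_actor_task_counters(algorithm_state, actor_id);
  // ...and the execution dependency of the earliest runnable task per handle.
  checkpoint.frontier = get_actor_frontier(algorithm_state, actor_id);
  return checkpoint;
}

void restore_actor_checkpoint(LocalSchedulerState *state,
                              SchedulingAlgorithmState *algorithm_state,
                              ActorID actor_id,
                              const ActorCheckpoint &checkpoint) {
  // Overwrite the counters first so queued tasks from before the checkpoint
  // are discarded, then set the frontier so the tasks that depend on it
  // become runnable again during reconstruction.
  set_actor_task_counters(algorithm_state, actor_id, checkpoint.task_counters);
  set_actor_frontier(state, algorithm_state, actor_id, checkpoint.frontier);
}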
- */ -int num_dispatch_tasks(SchedulingAlgorithmState *algorithm_state); -#endif - -#endif /* LOCAL_SCHEDULER_ALGORITHM_H */ diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc deleted file mode 100644 index 91b5fa9c9df1a..0000000000000 --- a/src/local_scheduler/local_scheduler_client.cc +++ /dev/null @@ -1,378 +0,0 @@ -#include "local_scheduler_client.h" - -#include "common_protocol.h" -#include "format/local_scheduler_generated.h" -#include "ray/raylet/format/node_manager_generated.h" - -#include "common/io.h" -#include "common/task.h" -#include -#include -#include - -using MessageType = ray::local_scheduler::protocol::MessageType; - -LocalSchedulerConnection *LocalSchedulerConnection_init( - const char *local_scheduler_socket, - const UniqueID &client_id, - bool is_worker, - const JobID &driver_id, - bool use_raylet, - const Language &language) { - LocalSchedulerConnection *result = new LocalSchedulerConnection(); - result->use_raylet = use_raylet; - result->conn = connect_ipc_sock_retry(local_scheduler_socket, -1, -1); - - /* Register with the local scheduler. - * NOTE(swang): If the local scheduler exits and we are registered as a - * worker, we will get killed. */ - flatbuffers::FlatBufferBuilder fbb; - if (use_raylet) { - auto message = ray::protocol::CreateRegisterClientRequest( - fbb, is_worker, to_flatbuf(fbb, client_id), getpid(), - to_flatbuf(fbb, driver_id), language); - fbb.Finish(message); - } else { - auto message = ray::local_scheduler::protocol::CreateRegisterClientRequest( - fbb, is_worker, to_flatbuf(fbb, client_id), getpid(), - to_flatbuf(fbb, driver_id)); - fbb.Finish(message); - } - /* Register the process ID with the local scheduler. */ - int success = write_message( - result->conn, static_cast(MessageType::RegisterClientRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &result->write_mutex); - RAY_CHECK(success == 0) << "Unable to register worker with local scheduler"; - - return result; -} - -void LocalSchedulerConnection_free(LocalSchedulerConnection *conn) { - close(conn->conn); - delete conn; -} - -void local_scheduler_disconnect_client(LocalSchedulerConnection *conn) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::local_scheduler::protocol::CreateDisconnectClient(fbb); - fbb.Finish(message); - write_message(conn->conn, static_cast(MessageType::DisconnectClient), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_log_event(LocalSchedulerConnection *conn, - uint8_t *key, - int64_t key_length, - uint8_t *value, - int64_t value_length, - double timestamp) { - flatbuffers::FlatBufferBuilder fbb; - auto key_string = fbb.CreateString((char *) key, key_length); - auto value_string = fbb.CreateString((char *) value, value_length); - auto message = ray::local_scheduler::protocol::CreateEventLogMessage( - fbb, key_string, value_string, timestamp); - fbb.Finish(message); - write_message(conn->conn, static_cast(MessageType::EventLogMessage), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_submit(LocalSchedulerConnection *conn, - const TaskExecutionSpec &execution_spec) { - flatbuffers::FlatBufferBuilder fbb; - auto execution_dependencies = - to_flatbuf(fbb, execution_spec.ExecutionDependencies()); - auto task_spec = - fbb.CreateString(reinterpret_cast(execution_spec.Spec()), - execution_spec.SpecSize()); - auto message = ray::local_scheduler::protocol::CreateSubmitTaskRequest( - fbb, execution_dependencies, task_spec); - 
fbb.Finish(message); - write_message(conn->conn, static_cast(MessageType::SubmitTask), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_submit_raylet( - LocalSchedulerConnection *conn, - const std::vector &execution_dependencies, - const ray::raylet::TaskSpecification &task_spec) { - flatbuffers::FlatBufferBuilder fbb; - auto execution_dependencies_message = to_flatbuf(fbb, execution_dependencies); - auto message = ray::local_scheduler::protocol::CreateSubmitTaskRequest( - fbb, execution_dependencies_message, task_spec.ToFlatbuffer(fbb)); - fbb.Finish(message); - write_message(conn->conn, static_cast(MessageType::SubmitTask), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn, - int64_t *task_size) { - int64_t type; - int64_t reply_size; - uint8_t *reply; - { - std::unique_lock guard(conn->mutex); - write_message(conn->conn, static_cast(MessageType::GetTask), 0, - NULL, &conn->write_mutex); - /* Receive a task from the local scheduler. This will block until the local - * scheduler gives this client a task. */ - read_message(conn->conn, &type, &reply_size, &reply); - } - if (type == static_cast(CommonMessageType::DISCONNECT_CLIENT)) { - RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection."; - exit(1); - } - RAY_CHECK(static_cast(type) == MessageType::ExecuteTask); - - /* Parse the flatbuffer object. */ - auto reply_message = - flatbuffers::GetRoot(reply); - - /* Create a copy of the task spec so we can free the reply. */ - *task_size = reply_message->task_spec()->size(); - TaskSpec *data = (TaskSpec *) reply_message->task_spec()->data(); - TaskSpec *spec = TaskSpec_copy(data, *task_size); - - // Set the GPU IDs for this task. We only do this for non-actor tasks because - // for actors the GPUs are associated with the actor itself and not with the - // actor methods. Note that this also processes GPUs for actor creation tasks. - if (!TaskSpec_is_actor_task(spec)) { - conn->gpu_ids.clear(); - for (size_t i = 0; i < reply_message->gpu_ids()->size(); ++i) { - conn->gpu_ids.push_back(reply_message->gpu_ids()->Get(i)); - } - } - - /* Free the original message from the local scheduler. */ - free(reply); - /* Return the copy of the task spec and pass ownership to the caller. */ - return spec; -} - -// This is temporarily duplicated from local_scheduler_get_task while we have -// the raylet and non-raylet code paths. -TaskSpec *local_scheduler_get_task_raylet(LocalSchedulerConnection *conn, - int64_t *task_size) { - int64_t type; - int64_t reply_size; - uint8_t *reply; - { - std::unique_lock guard(conn->mutex); - write_message(conn->conn, static_cast(MessageType::GetTask), 0, - NULL, &conn->write_mutex); - // Receive a task from the local scheduler. This will block until the local - // scheduler gives this client a task. - read_message(conn->conn, &type, &reply_size, &reply); - } - if (type == static_cast(CommonMessageType::DISCONNECT_CLIENT)) { - RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection."; - exit(1); - } - RAY_CHECK(type == static_cast(MessageType::ExecuteTask)); - - // Parse the flatbuffer object. - auto reply_message = flatbuffers::GetRoot(reply); - - // Create a copy of the task spec so we can free the reply. 
- *task_size = reply_message->task_spec()->size(); - const TaskSpec *data = - reinterpret_cast(reply_message->task_spec()->data()); - TaskSpec *spec = TaskSpec_copy(const_cast(data), *task_size); - - // Set the resource IDs for this task. - conn->resource_ids_.clear(); - for (size_t i = 0; i < reply_message->fractional_resource_ids()->size(); - ++i) { - auto const &fractional_resource_ids = - reply_message->fractional_resource_ids()->Get(i); - auto &acquired_resources = conn->resource_ids_[string_from_flatbuf( - *fractional_resource_ids->resource_name())]; - - size_t num_resource_ids = fractional_resource_ids->resource_ids()->size(); - size_t num_resource_fractions = - fractional_resource_ids->resource_fractions()->size(); - RAY_CHECK(num_resource_ids == num_resource_fractions); - RAY_CHECK(num_resource_ids > 0); - for (size_t j = 0; j < num_resource_ids; ++j) { - int64_t resource_id = fractional_resource_ids->resource_ids()->Get(j); - double resource_fraction = - fractional_resource_ids->resource_fractions()->Get(j); - if (num_resource_ids > 1) { - int64_t whole_fraction = resource_fraction; - RAY_CHECK(whole_fraction == resource_fraction); - } - acquired_resources.push_back( - std::make_pair(resource_id, resource_fraction)); - } - } - - // Free the original message from the local scheduler. - free(reply); - // Return the copy of the task spec and pass ownership to the caller. - return spec; -} - -void local_scheduler_task_done(LocalSchedulerConnection *conn) { - write_message(conn->conn, static_cast(MessageType::TaskDone), 0, - NULL, &conn->write_mutex); -} - -void local_scheduler_reconstruct_objects( - LocalSchedulerConnection *conn, - const std::vector &object_ids, - bool fetch_only) { - flatbuffers::FlatBufferBuilder fbb; - auto object_ids_message = to_flatbuf(fbb, object_ids); - auto message = ray::local_scheduler::protocol::CreateReconstructObjects( - fbb, object_ids_message, fetch_only); - fbb.Finish(message); - write_message(conn->conn, - static_cast(MessageType::ReconstructObjects), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - /* TODO(swang): Propagate the error. 
*/ -} - -void local_scheduler_log_message(LocalSchedulerConnection *conn) { - write_message(conn->conn, static_cast(MessageType::EventLogMessage), - 0, NULL, &conn->write_mutex); -} - -void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn) { - write_message(conn->conn, static_cast(MessageType::NotifyUnblocked), - 0, NULL, &conn->write_mutex); -} - -void local_scheduler_put_object(LocalSchedulerConnection *conn, - TaskID task_id, - ObjectID object_id) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::local_scheduler::protocol::CreatePutObject( - fbb, to_flatbuf(fbb, task_id), to_flatbuf(fbb, object_id)); - fbb.Finish(message); - - write_message(conn->conn, static_cast(MessageType::PutObject), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -const std::vector local_scheduler_get_actor_frontier( - LocalSchedulerConnection *conn, - ActorID actor_id) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::local_scheduler::protocol::CreateGetActorFrontierRequest( - fbb, to_flatbuf(fbb, actor_id)); - fbb.Finish(message); - int64_t type; - std::vector reply; - { - std::unique_lock guard(conn->mutex); - write_message(conn->conn, - static_cast(MessageType::GetActorFrontierRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - - read_vector(conn->conn, &type, reply); - } - if (static_cast(type) == - CommonMessageType::DISCONNECT_CLIENT) { - RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection."; - exit(1); - } - RAY_CHECK(static_cast(type) == - MessageType::GetActorFrontierReply); - return reply; -} - -void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn, - const std::vector &frontier) { - write_message(conn->conn, static_cast(MessageType::SetActorFrontier), - frontier.size(), const_cast(frontier.data()), - &conn->write_mutex); -} - -std::pair, std::vector> local_scheduler_wait( - LocalSchedulerConnection *conn, - const std::vector &object_ids, - int num_returns, - int64_t timeout_milliseconds, - bool wait_local) { - // Write request. - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateWaitRequest( - fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, - wait_local); - fbb.Finish(message); - int64_t type; - int64_t reply_size; - uint8_t *reply; - { - std::unique_lock guard(conn->mutex); - write_message(conn->conn, - static_cast(ray::protocol::MessageType::WaitRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - // Read result. - read_message(conn->conn, &type, &reply_size, &reply); - } - RAY_CHECK(static_cast(type) == - ray::protocol::MessageType::WaitReply); - auto reply_message = flatbuffers::GetRoot(reply); - // Convert result. - std::pair, std::vector> result; - auto found = reply_message->found(); - for (uint i = 0; i < found->size(); i++) { - ObjectID object_id = ObjectID::from_binary(found->Get(i)->str()); - result.first.push_back(object_id); - } - auto remaining = reply_message->remaining(); - for (uint i = 0; i < remaining->size(); i++) { - ObjectID object_id = ObjectID::from_binary(remaining->Get(i)->str()); - result.second.push_back(object_id); - } - /* Free the original message from the local scheduler. 
*/ - free(reply); - return result; -} - -void local_scheduler_push_error(LocalSchedulerConnection *conn, - const JobID &job_id, - const std::string &type, - const std::string &error_message, - double timestamp) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreatePushErrorRequest( - fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type), - fbb.CreateString(error_message), timestamp); - fbb.Finish(message); - - write_message(conn->conn, static_cast( - ray::protocol::MessageType::PushErrorRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_push_profile_events( - LocalSchedulerConnection *conn, - const ProfileTableDataT &profile_events) { - flatbuffers::FlatBufferBuilder fbb; - - auto message = CreateProfileTableData(fbb, &profile_events); - fbb.Finish(message); - - write_message(conn->conn, - static_cast( - ray::protocol::MessageType::PushProfileEventsRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_free_objects_in_object_store( - LocalSchedulerConnection *conn, - const std::vector &object_ids, - bool local_only) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateFreeObjectsRequest( - fbb, local_only, to_flatbuf(fbb, object_ids)); - fbb.Finish(message); - - int success = write_message( - conn->conn, - static_cast( - ray::protocol::MessageType::FreeObjectsInObjectStoreRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - RAY_CHECK(success == 0) << "Failed to write message to raylet."; -} diff --git a/src/local_scheduler/local_scheduler_client.h b/src/local_scheduler/local_scheduler_client.h deleted file mode 100644 index bb4fdb345896c..0000000000000 --- a/src/local_scheduler/local_scheduler_client.h +++ /dev/null @@ -1,260 +0,0 @@ -#ifndef LOCAL_SCHEDULER_CLIENT_H -#define LOCAL_SCHEDULER_CLIENT_H - -#include - -#include "common/task.h" -#include "local_scheduler_shared.h" -#include "ray/raylet/task_spec.h" - -struct LocalSchedulerConnection { - /// True if we should use the raylet code path and false otherwise. - bool use_raylet; - /** File descriptor of the Unix domain socket that connects to local - * scheduler. */ - int conn; - /** The IDs of the GPUs that this client can use. NOTE(rkn): This is only used - * by legacy Ray and will be deprecated. */ - std::vector gpu_ids; - /// A map from resource name to the resource IDs that are currently reserved - /// for this worker. Each pair consists of the resource ID and the fraction - /// of that resource allocated for this worker. - std::unordered_map>> - resource_ids_; - /// A mutex to protect stateful operations of the local scheduler client. - std::mutex mutex; - /// A mutext to protect write operations of the local scheduler client. - std::mutex write_mutex; -}; - -/** - * Connect to the local scheduler. - * - * @param local_scheduler_socket The name of the socket to use to connect to the - * local scheduler. - * @param worker_id A unique ID to represent the worker. - * @param is_worker Whether this client is a worker. If it is a worker, an - * additional message will be sent to register as one. - * @param driver_id The ID of the driver. This is non-nil if the client is a - * driver. - * @param use_raylet True if we should use the raylet code path and false - * otherwise. - * @return The connection information. 
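The resource_ids_ field documented in the struct above is what local_scheduler_get_task_raylet (in the deleted .cc earlier in this diff) fills in when a task is assigned: for each resource name, the list of (resource ID, fraction) pairs reserved for this worker. A small helper sketch, assuming the map is keyed by plain resource-name strings such as "GPU" and holds std::pair<int64_t, double> entries as the comments describe:

// Sketch: recover the GPU IDs granted to this worker after a task is
// assigned on the raylet path. The "GPU" key is an assumption.
#include <cstdint>
#include <utility>
#include <vector>
#include "local_scheduler_client.h"

std::vector<int64_t> granted_gpu_ids(const LocalSchedulerConnection &conn) {
  std::vector<int64_t> gpu_ids;
  auto it = conn.resource_ids_.find("GPU");
  if (it == conn.resource_ids_.end()) {
    return gpu_ids;  // No GPUs were reserved for this worker.
  }
  for (const auto &id_and_fraction : it->second) {
    // Each entry pairs a resource ID with the fraction of that resource
    // reserved for this worker (1.0 means a whole GPU).
    gpu_ids.push_back(id_and_fraction.first);
  }
  return gpu_ids;
}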
- */ -LocalSchedulerConnection *LocalSchedulerConnection_init( - const char *local_scheduler_socket, - const UniqueID &worker_id, - bool is_worker, - const JobID &driver_id, - bool use_raylet, - const Language &language); - -/** - * Disconnect from the local scheduler. - * - * @param conn Local scheduler connection information returned by - * LocalSchedulerConnection_init. - * @return Void. - */ -void LocalSchedulerConnection_free(LocalSchedulerConnection *conn); - -/** - * Submit a task to the local scheduler. - * - * @param conn The connection information. - * @param execution_spec The execution spec for the task to submit. - * @return Void. - */ -void local_scheduler_submit(LocalSchedulerConnection *conn, - const TaskExecutionSpec &execution_spec); - -/// Submit a task using the raylet code path. -/// -/// \param The connection information. -/// \param The execution dependencies. -/// \param The task specification. -/// \return Void. -void local_scheduler_submit_raylet( - LocalSchedulerConnection *conn, - const std::vector &execution_dependencies, - const ray::raylet::TaskSpecification &task_spec); - -/** - * Notify the local scheduler that this client is disconnecting gracefully. This - * is used by actors to exit gracefully so that the local scheduler doesn't - * propagate an error message to the driver. - * - * @param conn The connection information. - * @return Void. - */ -void local_scheduler_disconnect_client(LocalSchedulerConnection *conn); - -/** - * Log an event to the event log. This will call RPUSH key value. We use RPUSH - * instead of SET so that it is possible to flush the log multiple times with - * the same key (for example the key might be shared across logging calls in the - * same task on a worker). - * - * @param conn The connection information. - * @param key The key to store the event in. - * @param key_length The length of the key. - * @param value The value to store. - * @param value_length The length of the value. - * @param timestamp The time that the event is logged. - * @return Void. - */ -void local_scheduler_log_event(LocalSchedulerConnection *conn, - uint8_t *key, - int64_t key_length, - uint8_t *value, - int64_t value_length, - double timestamp); - -/** - * Get next task for this client. This will block until the scheduler assigns - * a task to this worker. This allocates and returns a task, and so the task - * must be freed by the caller. - * - * @todo When does this actually get freed? - * - * @param conn The connection information. - * @param task_size A pointer to fill out with the task size. - * @return The address of the assigned task. - */ -TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn, - int64_t *task_size); - -/// Get next task for this client. This will block until the scheduler assigns -/// a task to this worker. This allocates and returns a task, and so the task -/// must be freed by the caller. -/// -/// \param conn The connection information. -/// \param task_size A pointer to fill out with the task size. -/// \return The address of the assigned task. -TaskSpec *local_scheduler_get_task_raylet(LocalSchedulerConnection *conn, - int64_t *task_size); - -/** - * Tell the local scheduler that the client has finished executing a task. - * - * @param conn The connection information. - * @return Void. - */ -void local_scheduler_task_done(LocalSchedulerConnection *conn); - -/** - * Tell the local scheduler to reconstruct or fetch objects. - * - * @param conn The connection information. 
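Taken together, the declarations above define the worker side of the protocol. A rough lifecycle sketch for a worker on the legacy (non-raylet) path, mirroring how the tests later in this diff drive the API, might look like this; the socket path and the execution step are placeholders.

// Sketch of a worker's lifecycle against the client API above.
#include <cstdint>
#include <cstdlib>
#include "local_scheduler_client.h"

void run_one_task(const char *socket_name) {
  // Register as a worker; argument values mirror the tests below.
  LocalSchedulerConnection *conn = LocalSchedulerConnection_init(
      socket_name, WorkerID::nil(), /*is_worker=*/true, JobID::nil(),
      /*use_raylet=*/false, Language::PYTHON);

  // Block until the scheduler assigns a task; the returned spec is a copy
  // owned by the caller.
  int64_t task_size = 0;
  TaskSpec *spec = local_scheduler_get_task(conn, &task_size);

  // ... execute the task ...

  // Report completion, release the copied spec, then disconnect gracefully
  // so no error is propagated to the driver.
  local_scheduler_task_done(conn);
  free(spec);
  local_scheduler_disconnect_client(conn);
  LocalSchedulerConnection_free(conn);
}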
- * @param object_ids The IDs of the objects to reconstruct. - * @param fetch_only Only fetch objects, do not reconstruct them. - * @return Void. - */ -void local_scheduler_reconstruct_objects( - LocalSchedulerConnection *conn, - const std::vector &object_ids, - bool fetch_only = false); - -/** - * Send a log message to the local scheduler. - * - * @param conn The connection information. - * @return Void. - */ -void local_scheduler_log_message(LocalSchedulerConnection *conn); - -/** - * Notify the local scheduler that this client (worker) is no longer blocked. - * - * @param conn The connection information. - * @return Void. - */ -void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn); - -/** - * Record the mapping from object ID to task ID for put events. - * - * @param conn The connection information. - * @param task_id The ID of the task that called put. - * @param object_id The ID of the object being stored. - * @return Void. - */ -void local_scheduler_put_object(LocalSchedulerConnection *conn, - TaskID task_id, - ObjectID object_id); - -/** - * Get an actor's current task frontier. - * - * @param conn The connection information. - * @param actor_id The ID of the actor whose frontier is returned. - * @return A byte vector that can be traversed as an ActorFrontier flatbuffer. - */ -const std::vector local_scheduler_get_actor_frontier( - LocalSchedulerConnection *conn, - ActorID actor_id); - -/** - * Set an actor's current task frontier. - * - * @param conn The connection information. - * @param frontier An ActorFrontier flatbuffer to set the frontier to. - * @return Void. - */ -void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn, - const std::vector &frontier); - -/// Wait for the given objects until timeout expires or num_return objects are -/// found. -/// -/// \param conn The connection information. -/// \param object_ids The objects to wait for. -/// \param num_returns The number of objects to wait for. -/// \param timeout_milliseconds Duration, in milliseconds, to wait before -/// returning. -/// \param wait_local Whether to wait for objects to appear on this node. -/// \return A pair with the first element containing the object ids that were -/// found, and the second element the objects that were not found. -std::pair, std::vector> local_scheduler_wait( - LocalSchedulerConnection *conn, - const std::vector &object_ids, - int num_returns, - int64_t timeout_milliseconds, - bool wait_local); - -/// Push an error to the relevant driver. -/// -/// \param conn The connection information. -/// \param The ID of the job that the error is for. -/// \param The type of the error. -/// \param The error message. -/// \param The timestamp of the error. -/// \return Void. -void local_scheduler_push_error(LocalSchedulerConnection *conn, - const JobID &job_id, - const std::string &type, - const std::string &error_message, - double timestamp); - -/// Store some profile events in the GCS. -/// -/// \param conn The connection information. -/// \param profile_events A batch of profiling event information. -/// \return Void. -void local_scheduler_push_profile_events( - LocalSchedulerConnection *conn, - const ProfileTableDataT &profile_events); - -/// Free a list of objects from object stores. -/// -/// \param conn The connection information. -/// \param object_ids A list of ObjectsIDs to be deleted. -/// \param local_only Whether keep this request with local object store -/// or send it to all the object stores. -/// \return Void. 
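The wait declaration above pairs naturally with the reconstruct/fetch call: wait for what is already local, then ask the scheduler to fetch whatever is still missing. A usage sketch follows, with the ObjectID element type of the vectors inferred from the doc comments and the timeout chosen only for illustration.

// Sketch: wait for a set of objects, then fetch (without reconstructing)
// whatever did not appear in time.
#include <utility>
#include <vector>
#include "local_scheduler_client.h"

void wait_then_fetch_missing(LocalSchedulerConnection *conn,
                             const std::vector<ObjectID> &ids) {
  // Block for up to one second or until every object is available locally.
  auto result = local_scheduler_wait(conn, ids, static_cast<int>(ids.size()),
                                     /*timeout_milliseconds=*/1000,
                                     /*wait_local=*/true);
  const std::vector<ObjectID> &remaining = result.second;
  if (!remaining.empty()) {
    // fetch_only = true: bring the objects over without reconstruction.
    local_scheduler_reconstruct_objects(conn, remaining, /*fetch_only=*/true);
  }
}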
-void local_scheduler_free_objects_in_object_store( - LocalSchedulerConnection *conn, - const std::vector &object_ids, - bool local_only); - -#endif diff --git a/src/local_scheduler/local_scheduler_shared.h b/src/local_scheduler/local_scheduler_shared.h deleted file mode 100644 index 572f14a6fdf73..0000000000000 --- a/src/local_scheduler/local_scheduler_shared.h +++ /dev/null @@ -1,137 +0,0 @@ -#ifndef LOCAL_SCHEDULER_SHARED_H -#define LOCAL_SCHEDULER_SHARED_H - -#include "common/task.h" -#include "common/state/table.h" -#include "common/state/db.h" -#include "plasma/client.h" -#include "ray/gcs/client.h" - -#include -#include -#include -#include - -/** This struct is used to maintain a mapping from actor IDs to the ID of the - * local scheduler that is responsible for the actor. */ -struct ActorMapEntry { - /** The ID of the driver that created the actor. */ - WorkerID driver_id; - /** The ID of the local scheduler that is responsible for the actor. */ - DBClientID local_scheduler_id; -}; - -/** Internal state of the scheduling algorithm. */ -typedef struct SchedulingAlgorithmState SchedulingAlgorithmState; - -struct LocalSchedulerClient; - -/** A struct storing the configuration state of the local scheduler. This should - * consist of values that don't change over the lifetime of the local - * scheduler. */ -typedef struct { - /** The script to use when starting a new worker. */ - const char **start_worker_command; - /** Whether there is a global scheduler. */ - bool global_scheduler_exists; -} local_scheduler_config; - -/** The state of the local scheduler. */ -struct LocalSchedulerState { - /** The configuration for the local scheduler. */ - local_scheduler_config config; - /** The local scheduler event loop. */ - event_loop *loop; - /** List of workers available to this node. This is used to free the worker - * structs when we free the scheduler state and also to access the worker - * structs in the tests. */ - std::list workers; - /** A set of driver IDs corresponding to drivers that have been removed. This - * is used to make sure we don't execute any tasks belong to dead drivers. */ - std::unordered_set removed_drivers; - /** A set of actors IDs corresponding to local actors that have been removed. - * This ensures we can reject any tasks destined for dead actors. */ - std::unordered_set removed_actors; - /** List of the process IDs for child processes (workers) started by the - * local scheduler that have not sent a REGISTER_PID message yet. */ - std::vector child_pids; - /** A hash table mapping actor IDs to the db_client_id of the local scheduler - * that is responsible for the actor. */ - std::unordered_map actor_mapping; - /** The handle to the database. */ - DBHandle *db; - /** The Plasma client. */ - plasma::PlasmaClient *plasma_conn; - /** State for the scheduling algorithm. */ - SchedulingAlgorithmState *algorithm_state; - /** Input buffer, used for reading input in process_message to avoid - * allocation for each call to process_message. */ - std::vector input_buffer; - /** Vector of static attributes associated with the node owned by this local - * scheduler. */ - std::unordered_map static_resources; - /** Vector of dynamic attributes associated with the node owned by this local - * scheduler. */ - std::unordered_map dynamic_resources; - /** The IDs of the available GPUs. There is redundancy here in that - * available_gpus.size() == dynamic_resources[ResourceIndex_GPU] should - * always be true. 
*/ - std::vector available_gpus; - /** The time (in milliseconds since the Unix epoch) when the most recent - * heartbeat was sent. */ - int64_t previous_heartbeat_time; -}; - -/** Contains all information associated with a local scheduler client. */ -struct LocalSchedulerClient { - /** The socket used to communicate with the client. */ - int sock; - /** True if the client has registered and false otherwise. */ - bool registered; - /** True if the client has sent a disconnect message to the local scheduler - * and false otherwise. If this is true, then the local scheduler will not - * propagate an error message to the driver when the client exits. */ - bool disconnected; - /** True if the client is a worker and false if it is a driver. */ - bool is_worker; - /** The worker ID if the client is a worker and the driver ID if the client is - * a driver. */ - WorkerID client_id; - /** A pointer to the task object that is currently running on this client. If - * no task is running on the worker, this will be NULL. This is used to - * update the task table. */ - Task *task_in_progress; - /** An array of resource counts currently in use by the worker. */ - std::unordered_map resources_in_use; - /** A vector of the IDs of the GPUs that the worker is currently using. If the - * worker is an actor, this will be constant throughout the lifetime of the - * actor (and will be equal to the number of GPUs requested by the actor). If - * the worker is not an actor, this will be constant for the duration of a - * task and will have length equal to the number of GPUs requested by the - * task (in particular it will not change if the task blocks). */ - std::vector gpus_in_use; - /** A flag to indicate whether this worker is currently blocking on an - * object(s) that isn't available locally yet. */ - bool is_blocked; - /** The process ID of the client. If this is set to zero, the client has not - * yet registered a process ID. */ - pid_t pid; - /** Whether the client is a child process of the local scheduler. */ - bool is_child; - /** The ID of the actor on this worker. If there is no actor running on this - * worker, this should be NIL_ACTOR_ID. */ - ActorID actor_id; - /** A pointer to the local scheduler state. */ - LocalSchedulerState *local_scheduler_state; -}; - -/** - * Free the local scheduler state. This disconnects all clients and notifies - * the global scheduler of the local scheduler's exit. - * - * @param state The state to free. 
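As a concrete reading of the resource bookkeeping above: dynamic_resources tracks how much of each resource is currently free, and available_gpus must stay in step with its GPU entry. A hypothetical consistency check is sketched below; the string-keyed, double-valued map type and the "GPU" key are assumptions, since the comments above only state the invariant informally.

// Hypothetical check of the invariant described above: the number of free
// GPU IDs should equal the dynamic GPU capacity the scheduler has left.
#include <cassert>
#include <string>
#include "local_scheduler_shared.h"

void check_gpu_bookkeeping(const LocalSchedulerState &state) {
  double free_gpus = 0;
  auto it = state.dynamic_resources.find("GPU");
  if (it != state.dynamic_resources.end()) {
    free_gpus = it->second;
  }
  assert(static_cast<double>(state.available_gpus.size()) == free_gpus);
}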
- * @return Void - */ -void LocalSchedulerState_free(LocalSchedulerState *state); - -#endif /* LOCAL_SCHEDULER_SHARED_H */ diff --git a/src/local_scheduler/test/local_scheduler_tests.cc b/src/local_scheduler/test/local_scheduler_tests.cc deleted file mode 100644 index b155ea9494c84..0000000000000 --- a/src/local_scheduler/test/local_scheduler_tests.cc +++ /dev/null @@ -1,704 +0,0 @@ -#include "greatest.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "common.h" -#include "test/test_common.h" -#include "test/example_task.h" -#include "event_loop.h" -#include "io.h" -#include "task.h" -#include "state/object_table.h" -#include "state/task_table.h" -#include "state/redis.h" - -#include "local_scheduler_shared.h" -#include "local_scheduler.h" -#include "local_scheduler_algorithm.h" -#include "local_scheduler_client.h" - -SUITE(local_scheduler_tests); - -TaskBuilder *g_task_builder = NULL; - -const char *plasma_store_socket_name = "/tmp/plasma_store_socket_1"; -const char *plasma_manager_socket_name_format = "/tmp/plasma_manager_socket_%d"; -const char *local_scheduler_socket_name_format = - "/tmp/local_scheduler_socket_%d"; - -int64_t timeout_handler(event_loop *loop, int64_t id, void *context) { - event_loop_stop(loop); - return EVENT_LOOP_TIMER_DONE; -} - -typedef struct { - /** A socket to mock the Plasma manager. Clients (such as workers) that - * connect to this file descriptor must be accepted. */ - int plasma_manager_fd; - /** A socket to communicate with the Plasma store. */ - int plasma_store_fd; - /** Local scheduler's socket for IPC requests. */ - int local_scheduler_fd; - /** Local scheduler's local scheduler state. */ - LocalSchedulerState *local_scheduler_state; - /** Local scheduler's event loop. */ - event_loop *loop; - /** Number of local scheduler client connections, or mock workers. */ - int num_local_scheduler_conns; - /** Local scheduler client connections. */ - LocalSchedulerConnection **conns; -} LocalSchedulerMock; - -/** - * Register clients of the local scheduler. This function is started in a - * separate thread so enable a blocking call to register the clients. - */ -static void register_clients(int num_mock_workers, LocalSchedulerMock *mock) { - for (int i = 0; i < num_mock_workers; ++i) { - new_client_connection(mock->loop, mock->local_scheduler_fd, - (void *) mock->local_scheduler_state, 0); - LocalSchedulerClient *worker = mock->local_scheduler_state->workers.back(); - process_message(mock->local_scheduler_state->loop, worker->sock, worker, 0); - } -} - -LocalSchedulerMock *LocalSchedulerMock_init(int num_workers, - int num_mock_workers) { - const char *node_ip_address = "127.0.0.1"; - const char *redis_addr = node_ip_address; - int redis_port = 6379; - std::unordered_map static_resource_conf; - static_resource_conf["CPU"] = INT16_MAX; - static_resource_conf["GPU"] = 0; - LocalSchedulerMock *mock = - (LocalSchedulerMock *) malloc(sizeof(LocalSchedulerMock)); - memset(mock, 0, sizeof(LocalSchedulerMock)); - mock->loop = event_loop_create(); - /* Bind to the local scheduler port and initialize the local scheduler. 
*/ - std::string plasma_manager_socket_name = bind_ipc_sock_retry( - plasma_manager_socket_name_format, &mock->plasma_manager_fd); - mock->plasma_store_fd = - connect_ipc_sock_retry(plasma_store_socket_name, 5, 100); - std::string local_scheduler_socket_name = bind_ipc_sock_retry( - local_scheduler_socket_name_format, &mock->local_scheduler_fd); - RAY_CHECK(mock->plasma_store_fd >= 0 && mock->local_scheduler_fd >= 0); - - /* Construct worker command */ - std::stringstream worker_command_ss; - worker_command_ss << "python ../python/ray/workers/default_worker.py" - << " --node-ip-address=" << node_ip_address - << " --object-store-name=" << plasma_store_socket_name - << " --object-store-manager-name=" - << plasma_manager_socket_name - << " --local-scheduler-name=" << local_scheduler_socket_name - << " --redis-address=" << redis_addr << ":" << redis_port; - std::string worker_command = worker_command_ss.str(); - - mock->local_scheduler_state = LocalSchedulerState_init( - "127.0.0.1", mock->loop, redis_addr, redis_port, - local_scheduler_socket_name.c_str(), plasma_store_socket_name, - plasma_manager_socket_name.c_str(), NULL, false, static_resource_conf, - worker_command.c_str(), num_workers); - - /* Accept the workers as clients to the plasma manager. */ - for (int i = 0; i < num_workers; ++i) { - accept_client(mock->plasma_manager_fd); - } - - /* Connect a local scheduler client. */ - mock->num_local_scheduler_conns = num_mock_workers; - mock->conns = (LocalSchedulerConnection **) malloc( - sizeof(LocalSchedulerConnection *) * num_mock_workers); - - std::thread background_thread = - std::thread(register_clients, num_mock_workers, mock); - - for (int i = 0; i < num_mock_workers; ++i) { - mock->conns[i] = LocalSchedulerConnection_init( - local_scheduler_socket_name.c_str(), WorkerID::nil(), true, - JobID::nil(), false, Language::PYTHON); - } - - background_thread.join(); - - return mock; -} - -void LocalSchedulerMock_free(LocalSchedulerMock *mock) { - /* Disconnect clients. */ - for (int i = 0; i < mock->num_local_scheduler_conns; ++i) { - LocalSchedulerConnection_free(mock->conns[i]); - } - free(mock->conns); - - /* Kill all the workers and run the event loop again so that the task table - * updates propagate and the tasks in progress are freed. */ - while (mock->local_scheduler_state->workers.size() > 0) { - LocalSchedulerClient *worker = mock->local_scheduler_state->workers.front(); - kill_worker(mock->local_scheduler_state, worker, true, false); - } - event_loop_add_timer(mock->loop, 500, - (event_loop_timer_handler) timeout_handler, NULL); - event_loop_run(mock->loop); - - /* This also frees mock->loop. */ - LocalSchedulerState_free(mock->local_scheduler_state); - close(mock->plasma_store_fd); - close(mock->plasma_manager_fd); - free(mock); -} - -void reset_worker(LocalSchedulerMock *mock, LocalSchedulerClient *worker) { - if (worker->task_in_progress) { - Task_free(worker->task_in_progress); - worker->task_in_progress = NULL; - } -} - -/** - * Test that object reconstruction gets called. If a task gets submitted, - * assigned to a worker, and then reconstruction is triggered for its return - * value, the task should get assigned to a worker again. - */ -TEST object_reconstruction_test(void) { - LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(0, 1); - LocalSchedulerConnection *worker = local_scheduler->conns[0]; - - /* Create a task with zero dependencies and one return value. 
*/ - TaskExecutionSpec execution_spec = example_task_execution_spec(0, 1); - TaskSpec *spec = execution_spec.Spec(); - int64_t task_size = execution_spec.SpecSize(); - ObjectID return_id = TaskSpec_return(spec, 0); - - /* Add an empty object table entry for the object we want to reconstruct, to - * simulate it having been created and evicted. */ - const char *client_id = "clientid"; - /* Lookup the shard locations for the object table. */ - std::vector db_shards_addresses; - std::vector db_shards_ports; - redisContext *context = redisConnect("127.0.0.1", 6379); - get_redis_shards(context, db_shards_addresses, db_shards_ports); - redisFree(context); - /* There should only be one shard, so we can safely add the empty object - * table entry to the first one. */ - ASSERT(db_shards_addresses.size() == 1); - context = redisConnect(db_shards_addresses[0].c_str(), db_shards_ports[0]); - redisReply *reply = (redisReply *) redisCommand( - context, "RAY.OBJECT_TABLE_ADD %b %ld %b %s", return_id.data(), - sizeof(return_id), 1, NIL_DIGEST, (size_t) DIGEST_SIZE, client_id); - freeReplyObject(reply); - reply = (redisReply *) redisCommand(context, "RAY.OBJECT_TABLE_REMOVE %b %s", - return_id.data(), sizeof(return_id), - client_id); - freeReplyObject(reply); - redisFree(context); - - pid_t pid = fork(); - if (pid == 0) { - /* Make sure we receive the task twice. First from the initial submission, - * and second from the reconstruct request. */ - int64_t task_assigned_size; - local_scheduler_submit(worker, execution_spec); - TaskSpec *task_assigned = - local_scheduler_get_task(worker, &task_assigned_size); - ASSERT_EQ(memcmp(task_assigned, spec, task_size), 0); - ASSERT_EQ(task_assigned_size, task_size); - int64_t reconstruct_task_size; - TaskSpec *reconstruct_task = - local_scheduler_get_task(worker, &reconstruct_task_size); - ASSERT_EQ(memcmp(reconstruct_task, spec, task_size), 0); - ASSERT_EQ(reconstruct_task_size, task_size); - /* Clean up. */ - free(reconstruct_task); - free(task_assigned); - LocalSchedulerMock_free(local_scheduler); - exit(0); - } else { - /* Run the event loop. NOTE: OSX appears to require the parent process to - * listen for events on the open file descriptors. */ - event_loop_add_timer(local_scheduler->loop, 500, - (event_loop_timer_handler) timeout_handler, NULL); - event_loop_run(local_scheduler->loop); - /* Set the task's status to TaskStatus::DONE to prevent the race condition - * that would suppress object reconstruction. */ - Task *task = Task_alloc( - execution_spec, TaskStatus::DONE, - get_db_client_id(local_scheduler->local_scheduler_state->db)); - task_table_add_task(local_scheduler->local_scheduler_state->db, task, NULL, - NULL, NULL); - - /* Trigger reconstruction, and run the event loop again. */ - ObjectID return_id = TaskSpec_return(spec, 0); - local_scheduler_reconstruct_objects(worker, - std::vector({return_id})); - event_loop_add_timer(local_scheduler->loop, 500, - (event_loop_timer_handler) timeout_handler, NULL); - event_loop_run(local_scheduler->loop); - /* Wait for the child process to exit and check that there are no tasks - * left in the local scheduler's task queue. Then, clean up. */ - wait(NULL); - ASSERT_EQ(num_waiting_tasks( - local_scheduler->local_scheduler_state->algorithm_state), - 0); - ASSERT_EQ(num_dispatch_tasks( - local_scheduler->local_scheduler_state->algorithm_state), - 0); - LocalSchedulerMock_free(local_scheduler); - PASS(); - } -} - -/** - * Test that object reconstruction gets recursively called. 
In a chain of - * tasks, if all inputs are lost, then reconstruction of the final object - * should trigger reconstruction of all previous tasks in the lineage. - */ -TEST object_reconstruction_recursive_test(void) { - LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(0, 1); - LocalSchedulerConnection *worker = local_scheduler->conns[0]; - /* Create a chain of tasks, each one dependent on the one before it. Mark - * each object as available so that tasks will run immediately. */ - const int NUM_TASKS = 10; - std::vector specs; - specs.push_back(example_task_execution_spec(0, 1)); - for (int i = 1; i < NUM_TASKS; ++i) { - ObjectID arg_id = TaskSpec_return(specs[i - 1].Spec(), 0); - specs.push_back(example_task_execution_spec_with_args(1, 1, &arg_id)); - } - /* Lookup the shard locations for the object table. */ - const char *client_id = "clientid"; - std::vector db_shards_addresses; - std::vector db_shards_ports; - redisContext *context = redisConnect("127.0.0.1", 6379); - get_redis_shards(context, db_shards_addresses, db_shards_ports); - redisFree(context); - /* There should only be one shard, so we can safely add the empty object - * table entry to the first one. */ - ASSERT(db_shards_addresses.size() == 1); - context = redisConnect(db_shards_addresses[0].c_str(), db_shards_ports[0]); - for (int i = 0; i < NUM_TASKS; ++i) { - ObjectID return_id = TaskSpec_return(specs[i].Spec(), 0); - redisReply *reply = (redisReply *) redisCommand( - context, "RAY.OBJECT_TABLE_ADD %b %ld %b %s", return_id.data(), - sizeof(return_id), 1, NIL_DIGEST, (size_t) DIGEST_SIZE, client_id); - freeReplyObject(reply); - reply = (redisReply *) redisCommand( - context, "RAY.OBJECT_TABLE_REMOVE %b %s", return_id.data(), - sizeof(return_id), client_id); - freeReplyObject(reply); - } - redisFree(context); - - pid_t pid = fork(); - if (pid == 0) { - /* Submit the tasks, and make sure each one gets assigned to a worker. */ - for (int i = 0; i < NUM_TASKS; ++i) { - local_scheduler_submit(worker, specs[i]); - } - /* Make sure we receive each task from the initial submission. */ - for (int i = 0; i < NUM_TASKS; ++i) { - int64_t task_size; - TaskSpec *task_assigned = local_scheduler_get_task(worker, &task_size); - ASSERT_EQ(memcmp(task_assigned, specs[i].Spec(), specs[i].SpecSize()), 0); - ASSERT_EQ(task_size, specs[i].SpecSize()); - free(task_assigned); - } - /* Check that the workers receive all tasks in the final return object's - * lineage during reconstruction. */ - for (int i = 0; i < NUM_TASKS; ++i) { - int64_t task_assigned_size; - TaskSpec *task_assigned = - local_scheduler_get_task(worker, &task_assigned_size); - for (auto it = specs.begin(); it != specs.end(); it++) { - if (memcmp(task_assigned, it->Spec(), task_assigned_size) == 0) { - specs.erase(it); - break; - } - } - free(task_assigned); - } - ASSERT(specs.size() == 0); - LocalSchedulerMock_free(local_scheduler); - exit(0); - } else { - /* Simulate each task putting its return values in the object store so that - * the next task can run. */ - for (int i = 0; i < NUM_TASKS; ++i) { - ObjectID return_id = TaskSpec_return(specs[i].Spec(), 0); - handle_object_available( - local_scheduler->local_scheduler_state, - local_scheduler->local_scheduler_state->algorithm_state, return_id); - } - /* Run the event loop. All tasks should now be dispatched. NOTE: OSX - * appears to require the parent process to listen for events on the open - * file descriptors. 
*/ - event_loop_add_timer(local_scheduler->loop, 500, - (event_loop_timer_handler) timeout_handler, NULL); - event_loop_run(local_scheduler->loop); - /* Set the final task's status to TaskStatus::DONE to prevent the race - * condition that would suppress object reconstruction. */ - Task *last_task = Task_alloc( - specs[NUM_TASKS - 1], TaskStatus::DONE, - get_db_client_id(local_scheduler->local_scheduler_state->db)); - task_table_add_task(local_scheduler->local_scheduler_state->db, last_task, - NULL, NULL, NULL); - /* Simulate eviction of the objects, so that reconstruction is required. */ - for (int i = 0; i < NUM_TASKS; ++i) { - ObjectID return_id = TaskSpec_return(specs[i].Spec(), 0); - handle_object_removed(local_scheduler->local_scheduler_state, return_id); - } - /* Trigger reconstruction for the last object. */ - ObjectID return_id = TaskSpec_return(specs[NUM_TASKS - 1].Spec(), 0); - local_scheduler_reconstruct_objects(worker, - std::vector({return_id})); - /* Run the event loop again. All tasks should be resubmitted. */ - event_loop_add_timer(local_scheduler->loop, 500, - (event_loop_timer_handler) timeout_handler, NULL); - event_loop_run(local_scheduler->loop); - /* Simulate each task putting its return values in the object store so that - * the next task can run. */ - for (int i = 0; i < NUM_TASKS; ++i) { - ObjectID return_id = TaskSpec_return(specs[i].Spec(), 0); - handle_object_available( - local_scheduler->local_scheduler_state, - local_scheduler->local_scheduler_state->algorithm_state, return_id); - } - /* Run the event loop again. All tasks should be dispatched again. */ - event_loop_add_timer(local_scheduler->loop, 500, - (event_loop_timer_handler) timeout_handler, NULL); - event_loop_run(local_scheduler->loop); - /* Wait for the child process to exit and check that there are no tasks - * left in the local scheduler's task queue. Then, clean up. */ - wait(NULL); - ASSERT_EQ(num_waiting_tasks( - local_scheduler->local_scheduler_state->algorithm_state), - 0); - ASSERT_EQ(num_dispatch_tasks( - local_scheduler->local_scheduler_state->algorithm_state), - 0); - specs.clear(); - LocalSchedulerMock_free(local_scheduler); - PASS(); - } -} - -/** - * Test that object reconstruction gets suppressed when there is a location - * listed for the object in the object table. - */ -TaskExecutionSpec *object_reconstruction_suppression_spec; - -void object_reconstruction_suppression_callback(ObjectID object_id, - bool success, - void *user_context) { - RAY_CHECK(success); - /* Submit the task after adding the object to the object table. */ - LocalSchedulerConnection *worker = (LocalSchedulerConnection *) user_context; - local_scheduler_submit(worker, *object_reconstruction_suppression_spec); -} - -TEST object_reconstruction_suppression_test(void) { - LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(0, 1); - LocalSchedulerConnection *worker = local_scheduler->conns[0]; - - TaskExecutionSpec execution_spec = example_task_execution_spec(0, 1); - object_reconstruction_suppression_spec = &execution_spec; - ObjectID return_id = - TaskSpec_return(object_reconstruction_suppression_spec->Spec(), 0); - pid_t pid = fork(); - if (pid == 0) { - /* Make sure we receive the task once. This will block until the - * object_table_add callback completes. 
-     */
-    int64_t task_assigned_size;
-    TaskSpec *task_assigned =
-        local_scheduler_get_task(worker, &task_assigned_size);
-    ASSERT_EQ(
-        memcmp(task_assigned, object_reconstruction_suppression_spec->Spec(),
-               object_reconstruction_suppression_spec->SpecSize()),
-        0);
-    /* Trigger a reconstruction. We will check that no tasks get queued as a
-     * result of this line in the event loop process. */
-    local_scheduler_reconstruct_objects(worker,
-                                        std::vector<ObjectID>({return_id}));
-    /* Clean up. */
-    free(task_assigned);
-    LocalSchedulerMock_free(local_scheduler);
-    exit(0);
-  } else {
-    /* Connect a plasma manager client so we can call object_table_add. */
-    std::vector<std::string> db_connect_args;
-    db_connect_args.push_back("manager_address");
-    db_connect_args.push_back("127.0.0.1:12346");
-    DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
-                              "127.0.0.1", db_connect_args);
-    db_attach(db, local_scheduler->loop, false);
-    /* Add the object to the object table. */
-    object_table_add(db, return_id, 1, (unsigned char *) NIL_DIGEST, NULL,
-                     object_reconstruction_suppression_callback,
-                     (void *) worker);
-    /* Run the event loop. NOTE: OSX appears to require the parent process to
-     * listen for events on the open file descriptors. */
-    event_loop_add_timer(local_scheduler->loop, 1000,
-                         (event_loop_timer_handler) timeout_handler, NULL);
-    event_loop_run(local_scheduler->loop);
-    /* Wait for the child process to exit and check that there are no tasks
-     * left in the local scheduler's task queue. Then, clean up. */
-    wait(NULL);
-    ASSERT_EQ(num_waiting_tasks(
-                  local_scheduler->local_scheduler_state->algorithm_state),
-              0);
-    ASSERT_EQ(num_dispatch_tasks(
-                  local_scheduler->local_scheduler_state->algorithm_state),
-              0);
-    db_disconnect(db);
-    LocalSchedulerMock_free(local_scheduler);
-    PASS();
-  }
-}
-
-TEST task_dependency_test(void) {
-  LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(0, 1);
-  LocalSchedulerState *state = local_scheduler->local_scheduler_state;
-  SchedulingAlgorithmState *algorithm_state = state->algorithm_state;
-  /* Get the first worker. */
-  LocalSchedulerClient *worker = state->workers.front();
-  TaskExecutionSpec execution_spec = example_task_execution_spec(1, 1);
-  TaskSpec *spec = execution_spec.Spec();
-  ObjectID oid = TaskSpec_arg_id(spec, 0, 0);
-
-  /* Check that the task gets queued in the waiting queue if the task is
-   * submitted, but the input and workers are not available. */
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Once the input is available, the task gets moved to the dispatch queue. */
-  handle_object_available(state, algorithm_state, oid);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* Once a worker is available, the task gets assigned. */
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  reset_worker(local_scheduler, worker);
-
-  /* Check that the task gets queued in the waiting queue if the task is
-   * submitted and a worker is available, but the input is not.
-   */
-  handle_object_removed(state, oid);
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Once the input is available, the task gets assigned. */
-  handle_object_available(state, algorithm_state, oid);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  reset_worker(local_scheduler, worker);
-
-  /* Check that the task gets queued in the dispatch queue if the task is
-   * submitted and the input is available, but no worker is available yet. */
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* Once a worker is available, the task gets assigned. */
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  reset_worker(local_scheduler, worker);
-
-  /* If an object gets removed, check the first scenario again, where the task
-   * gets queued in the waiting task if the task is submitted and a worker is
-   * available, but the input is not. */
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* If the input is removed while a task is in the dispatch queue, the task
-   * gets moved back to the waiting queue. */
-  handle_object_removed(state, oid);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Once the input is available, the task gets moved back to the dispatch
-   * queue. */
-  handle_object_available(state, algorithm_state, oid);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* Once a worker is available, the task gets assigned. */
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-
-  LocalSchedulerMock_free(local_scheduler);
-  PASS();
-}
-
-TEST task_multi_dependency_test(void) {
-  LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(0, 1);
-  LocalSchedulerState *state = local_scheduler->local_scheduler_state;
-  SchedulingAlgorithmState *algorithm_state = state->algorithm_state;
-  /* Get the first worker. */
-  LocalSchedulerClient *worker = state->workers.front();
-  TaskExecutionSpec execution_spec = example_task_execution_spec(2, 1);
-  TaskSpec *spec = execution_spec.Spec();
-  ObjectID oid1 = TaskSpec_arg_id(spec, 0, 0);
-  ObjectID oid2 = TaskSpec_arg_id(spec, 1, 0);
-
-  /* Check that the task gets queued in the waiting queue if the task is
-   * submitted, but the inputs and workers are not available. */
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Check that the task stays in the waiting queue if only one input becomes
-   * available. */
-  handle_object_available(state, algorithm_state, oid2);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Once all inputs are available, the task is moved to the dispatch queue.
-   */
-  handle_object_available(state, algorithm_state, oid1);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* Once a worker is available, the task gets assigned. */
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  reset_worker(local_scheduler, worker);
-
-  /* Check that the task gets queued in the dispatch queue if the task is
-   * submitted and the inputs are available, but no worker is available yet. */
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* If any input is removed while a task is in the dispatch queue, the task
-   * gets moved back to the waiting queue. */
-  handle_object_removed(state, oid1);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  handle_object_removed(state, oid2);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Check that the task stays in the waiting queue if only one input becomes
-   * available. */
-  handle_object_available(state, algorithm_state, oid2);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Check that the task stays in the waiting queue if the one input is
-   * unavailable again. */
-  handle_object_removed(state, oid2);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Check that the task stays in the waiting queue if the other input becomes
-   * available. */
-  handle_object_available(state, algorithm_state, oid1);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Once all inputs are available, the task is moved to the dispatch queue. */
-  handle_object_available(state, algorithm_state, oid2);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* Once a worker is available, the task gets assigned. */
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  reset_worker(local_scheduler, worker);
-
-  LocalSchedulerMock_free(local_scheduler);
-  PASS();
-}
-
-TEST start_kill_workers_test(void) {
-  /* Start some workers. */
-  int num_workers = 4;
-  LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(num_workers, 0);
-  /* We start off with num_workers children processes, but no workers
-   * registered yet. */
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(),
-            static_cast<size_t>(num_workers));
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(), 0);
-
-  /* Make sure that each worker connects to the local_scheduler scheduler. This
-   * for loop will hang if one of the workers does not connect. */
-  for (int i = 0; i < num_workers; ++i) {
-    new_client_connection(local_scheduler->loop,
-                          local_scheduler->local_scheduler_fd,
-                          (void *) local_scheduler->local_scheduler_state, 0);
-  }
-
-  /* After handling each worker's initial connection, we should now have all
-   * workers accounted for, but we haven't yet matched up process IDs with our
-   * children processes.
-   */
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(),
-            static_cast<size_t>(num_workers));
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers));
-
-  /* Each worker should register its process ID. */
-  for (auto const &worker : local_scheduler->local_scheduler_state->workers) {
-    process_message(local_scheduler->local_scheduler_state->loop, worker->sock,
-                    worker, 0);
-  }
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 0);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers));
-
-  /* After killing a worker, its state is cleaned up. */
-  LocalSchedulerClient *worker =
-      local_scheduler->local_scheduler_state->workers.front();
-  kill_worker(local_scheduler->local_scheduler_state, worker, false, false);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 0);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers - 1));
-
-  /* Start a worker after the local scheduler has been initialized. */
-  start_worker(local_scheduler->local_scheduler_state);
-  /* Accept the workers as clients to the plasma manager. */
-  int new_worker_fd = accept_client(local_scheduler->plasma_manager_fd);
-  /* The new worker should register its process ID. */
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 1);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers - 1));
-  /* Make sure the new worker connects to the local_scheduler scheduler. */
-  new_client_connection(local_scheduler->loop,
-                        local_scheduler->local_scheduler_fd,
-                        (void *) local_scheduler->local_scheduler_state, 0);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 1);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers));
-  /* Make sure that the new worker registers its process ID. */
-  worker = local_scheduler->local_scheduler_state->workers.back();
-  process_message(local_scheduler->local_scheduler_state->loop, worker->sock,
-                  worker, 0);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 0);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers));
-
-  /* Clean up. */
-  close(new_worker_fd);
-  LocalSchedulerMock_free(local_scheduler);
-  PASS();
-}
-
-SUITE(local_scheduler_tests) {
-  RUN_REDIS_TEST(object_reconstruction_test);
-  RUN_REDIS_TEST(object_reconstruction_recursive_test);
-  RUN_REDIS_TEST(object_reconstruction_suppression_test);
-  RUN_REDIS_TEST(task_dependency_test);
-  RUN_REDIS_TEST(task_multi_dependency_test);
-  RUN_REDIS_TEST(start_kill_workers_test);
-}
-
-GREATEST_MAIN_DEFS();
-
-int main(int argc, char **argv) {
-  g_task_builder = make_task_builder();
-  GREATEST_MAIN_BEGIN();
-  RUN_SUITE(local_scheduler_tests);
-  GREATEST_MAIN_END();
-}
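The deleted tests above exercise the local scheduler's two-queue model: a submitted task sits in a waiting queue until every input object is locally available, moves to a dispatch queue once it is runnable, is handed out only when a worker becomes free, and is demoted back to the waiting queue if an input disappears. The following self-contained C++ sketch mirrors the transitions those assertions check; it uses hypothetical names (TwoQueueScheduler, Submit, ObjectAvailable, and so on are illustrative only, not the scheduler's real API).

    // Minimal sketch of the waiting/dispatch queue state machine the deleted
    // tests assert on. All names here are hypothetical.
    #include <cassert>
    #include <deque>
    #include <set>
    #include <string>
    #include <vector>

    struct Task {
      std::vector<std::string> inputs;  // object IDs this task depends on
    };

    class TwoQueueScheduler {
     public:
      void Submit(const Task &task) {
        (Ready(task) ? dispatch_ : waiting_).push_back(task);
      }
      void ObjectAvailable(const std::string &oid) {
        available_.insert(oid);
        // Promote newly runnable tasks from waiting to dispatch.
        std::deque<Task> still_waiting;
        for (const Task &t : waiting_) {
          (Ready(t) ? dispatch_ : still_waiting).push_back(t);
        }
        waiting_.swap(still_waiting);
      }
      void ObjectRemoved(const std::string &oid) {
        available_.erase(oid);
        // Tasks whose inputs vanished move back from dispatch to waiting.
        std::deque<Task> still_ready;
        for (const Task &t : dispatch_) {
          (Ready(t) ? still_ready : waiting_).push_back(t);
        }
        dispatch_.swap(still_ready);
      }
      bool AssignToWorker() {
        if (dispatch_.empty()) return false;
        dispatch_.pop_front();  // hand the front task to the idle worker
        return true;
      }
      size_t NumWaiting() const { return waiting_.size(); }
      size_t NumDispatch() const { return dispatch_.size(); }

     private:
      bool Ready(const Task &t) const {
        for (const std::string &oid : t.inputs) {
          if (available_.count(oid) == 0) return false;
        }
        return true;
      }
      std::set<std::string> available_;
      std::deque<Task> waiting_;   // inputs missing
      std::deque<Task> dispatch_;  // runnable, waiting for a worker
    };

    int main() {
      TwoQueueScheduler s;
      s.Submit(Task{{"oid1", "oid2"}});
      assert(s.NumWaiting() == 1 && s.NumDispatch() == 0);
      s.ObjectAvailable("oid2");   // one of two inputs: still waiting
      assert(s.NumWaiting() == 1);
      s.ObjectAvailable("oid1");   // all inputs ready: moves to dispatch
      assert(s.NumDispatch() == 1);
      s.ObjectRemoved("oid1");     // input lost: demoted back to waiting
      assert(s.NumWaiting() == 1 && s.NumDispatch() == 0);
      s.ObjectAvailable("oid1");
      assert(s.AssignToWorker());  // worker shows up: task is assigned
      return 0;
    }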
diff --git a/src/local_scheduler/test/run_tests.sh b/src/local_scheduler/test/run_tests.sh
deleted file mode 100644
index 9c1d7be79b788..0000000000000
--- a/src/local_scheduler/test/run_tests.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/usr/bin/env bash
-
-# This needs to be run in the build tree, which is normally ray/build
-
-# Cause the script to exit if a single command fails.
-set -e
-
-LaunchRedis() {
-    port=$1
-    if [[ "${RAY_USE_NEW_GCS}" = "on" ]]; then
-        ./src/credis/redis/src/redis-server \
-            --loglevel warning \
-            --loadmodule ./src/credis/build/src/libmember.so \
-            --loadmodule ./src/common/redis_module/libray_redis_module.so \
-            --port $port &
-    else
-        ./src/common/thirdparty/redis/src/redis-server \
-            --loglevel warning \
-            --loadmodule ./src/common/redis_module/libray_redis_module.so \
-            --port $port &
-    fi
-}
-
-
-# Start the Redis shards.
-LaunchRedis 6379
-LaunchRedis 6380
-sleep 1s
-# Register the shard location with the primary shard.
-./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
-./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
-
-./src/plasma/plasma_store_server -s /tmp/plasma_store_socket_1 -m 100000000 &
-sleep 0.5s
-./src/local_scheduler/local_scheduler_tests
-./src/common/thirdparty/redis/src/redis-cli shutdown
-./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
-killall plasma_store_server
diff --git a/src/local_scheduler/test/run_valgrind.sh b/src/local_scheduler/test/run_valgrind.sh
deleted file mode 100644
index 6ff1dbe33c628..0000000000000
--- a/src/local_scheduler/test/run_valgrind.sh
+++ /dev/null
@@ -1,41 +0,0 @@
-#!/usr/bin/env bash
-
-# This needs to be run in the build tree, which is normally ray/build
-
-set -x
-
-# Cause the script to exit if a single command fails.
-set -e
-
-LaunchRedis() {
-    port=$1
-    if [[ "${RAY_USE_NEW_GCS}" = "on" ]]; then
-        ./src/credis/redis/src/redis-server \
-            --loglevel warning \
-            --loadmodule ./src/credis/build/src/libmember.so \
-            --loadmodule ./src/common/redis_module/libray_redis_module.so \
-            --port $port &
-    else
-        ./src/common/thirdparty/redis/src/redis-server \
-            --loglevel warning \
-            --loadmodule ./src/common/redis_module/libray_redis_module.so \
-            --port $port &
-    fi
-}
-
-
-# Start the Redis shards.
-LaunchRedis 6379
-LaunchRedis 6380
-sleep 1s
-
-# Register the shard location with the primary shard.
-./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
-./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
-
-./src/plasma/plasma_store_server -s /tmp/plasma_store_socket_1 -m 100000000 &
-sleep 0.5s
-valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/local_scheduler/local_scheduler_tests
-./src/common/thirdparty/redis/src/redis-cli shutdown
-./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
-killall plasma_store_server
diff --git a/src/plasma/CMakeLists.txt b/src/plasma/CMakeLists.txt
deleted file mode 100644
index 5037a54da3d71..0000000000000
--- a/src/plasma/CMakeLists.txt
+++ /dev/null
@@ -1,61 +0,0 @@
-cmake_minimum_required(VERSION 3.4)
-
-project(plasma)
-
-include_directories(${CMAKE_CURRENT_LIST_DIR})
-include_directories(${CMAKE_CURRENT_LIST_DIR}/thirdparty)
-
-set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} --std=c99 -O3")
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11 -O3 -Werror -Wall")
-
-if(UNIX AND NOT APPLE)
-  link_libraries(rt)
-endif()
-
-include_directories("${ARROW_INCLUDE_DIR}")
-
-set(PLASMA_FBS_SRC "${CMAKE_CURRENT_LIST_DIR}/format/plasma.fbs" "${CMAKE_CURRENT_LIST_DIR}/format/common.fbs")
-set(OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/format/)
-
-set(PLASMA_FBS_OUTPUT_FILES
-  "${OUTPUT_DIR}/plasma_generated.h"
-  "${OUTPUT_DIR}/common_generated.h")
-
-add_custom_target(gen_plasma_fbs DEPENDS ${PLASMA_FBS_OUTPUT_FILES})
-add_dependencies(gen_plasma_fbs arrow_ep)
-
-# Copy the fbs files from Arrow project to local directory.
-add_custom_command(
-  OUTPUT ${PLASMA_FBS_SRC}
-  COMMAND mkdir -p ${CMAKE_CURRENT_LIST_DIR}/format/
-  COMMAND cp ${ARROW_SOURCE_DIR}/cpp/src/plasma/format/plasma.fbs ${CMAKE_CURRENT_LIST_DIR}/format/
-  COMMAND cp ${ARROW_SOURCE_DIR}/cpp/src/plasma/format/common.fbs ${CMAKE_CURRENT_LIST_DIR}/format/
-  COMMENT "Copying ${PLASMA_FBS_SRC} to local"
-  VERBATIM)
-
-# Compile flatbuffers
-add_custom_command(
-  OUTPUT ${PLASMA_FBS_OUTPUT_FILES}
-  # The --gen-object-api flag generates a C++ class MessageT for each
-  # flatbuffers message Message, which can be used to store deserialized
-  # messages in data structures. This is currently used for ObjectInfo for
-  # example.
-  COMMAND ${FLATBUFFERS_COMPILER} -c -o ${OUTPUT_DIR} ${PLASMA_FBS_SRC} --gen-object-api --scoped-enums
-  DEPENDS ${PLASMA_FBS_SRC}
-  COMMENT "Running flatc compiler on ${PLASMA_FBS_SRC}"
-  VERBATIM)
-
-include_directories("${FLATBUFFERS_INCLUDE_DIR}")
-
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
-
-add_executable(plasma_manager
-  plasma_manager.cc)
-add_dependencies(plasma_manager gen_plasma_fbs)
-
-target_link_libraries(plasma_manager common ${PLASMA_STATIC_LIB} ray_static ${ARROW_STATIC_LIB} -lpthread ${Boost_SYSTEM_LIBRARY})
-
-define_test(client_tests "")
-define_test(manager_tests "" plasma_manager.cc)
-target_link_libraries(manager_tests ${Boost_SYSTEM_LIBRARY})
-add_dependencies(manager_tests gen_plasma_fbs)
diff --git a/src/plasma/doc/plasma-doxy-config b/src/plasma/doc/plasma-doxy-config
deleted file mode 100644
index 9c291f8388833..0000000000000
--- a/src/plasma/doc/plasma-doxy-config
+++ /dev/null
@@ -1,2473 +0,0 @@
-# Doxyfile 1.8.13
-
-# This file describes the settings to be used by the documentation system
-# doxygen (www.doxygen.org) for a project.
-#
-# All text after a double hash (##) is considered a comment and is placed in
-# front of the TAG it is preceding.
-#
-# All text after a single hash (#) is considered a comment and will be ignored.
-# The format is: -# TAG = value [value, ...] -# For lists, items can also be appended using: -# TAG += value [value, ...] -# Values that contain spaces should be placed between quotes (\" \"). - -#--------------------------------------------------------------------------- -# Project related configuration options -#--------------------------------------------------------------------------- - -# This tag specifies the encoding used for all characters in the config file -# that follow. The default is UTF-8 which is also the encoding used for all text -# before the first occurrence of this tag. Doxygen uses libiconv (or the iconv -# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv -# for the list of possible encodings. -# The default value is: UTF-8. - -DOXYFILE_ENCODING = UTF-8 - -# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by -# double-quotes, unless you are using Doxywizard) that should identify the -# project for which the documentation is generated. This name is used in the -# title of most generated pages and in a few other places. -# The default value is: My Project. - -PROJECT_NAME = "Plasma" - -# The PROJECT_NUMBER tag can be used to enter a project or revision number. This -# could be handy for archiving the generated documentation or if some version -# control system is used. - -PROJECT_NUMBER = - -# Using the PROJECT_BRIEF tag one can provide an optional one line description -# for a project that appears at the top of each page and should give viewer a -# quick idea about the purpose of the project. Keep the description short. - -PROJECT_BRIEF = - -# With the PROJECT_LOGO tag one can specify a logo or an icon that is included -# in the documentation. The maximum height of the logo should not exceed 55 -# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy -# the logo to the output directory. - -PROJECT_LOGO = - -# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path -# into which the generated documentation will be written. If a relative path is -# entered, it will be relative to the location where doxygen was started. If -# left blank the current directory will be used. - -OUTPUT_DIRECTORY = - -# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- -# directories (in 2 levels) under the output directory of each output format and -# will distribute the generated files over these directories. Enabling this -# option can be useful when feeding doxygen a huge amount of source files, where -# putting all generated files in the same directory would otherwise causes -# performance problems for the file system. -# The default value is: NO. - -CREATE_SUBDIRS = NO - -# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII -# characters to appear in the names of generated files. If set to NO, non-ASCII -# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode -# U+3044. -# The default value is: NO. - -ALLOW_UNICODE_NAMES = NO - -# The OUTPUT_LANGUAGE tag is used to specify the language in which all -# documentation generated by doxygen is written. Doxygen will use this -# information to generate all constant output in the proper language. 
-# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese, -# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States), -# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian, -# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages), -# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian, -# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian, -# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish, -# Ukrainian and Vietnamese. -# The default value is: English. - -OUTPUT_LANGUAGE = English - -# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member -# descriptions after the members that are listed in the file and class -# documentation (similar to Javadoc). Set to NO to disable this. -# The default value is: YES. - -BRIEF_MEMBER_DESC = YES - -# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief -# description of a member or function before the detailed description -# -# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the -# brief descriptions will be completely suppressed. -# The default value is: YES. - -REPEAT_BRIEF = YES - -# This tag implements a quasi-intelligent brief description abbreviator that is -# used to form the text in various listings. Each string in this list, if found -# as the leading text of the brief description, will be stripped from the text -# and the result, after processing the whole list, is used as the annotated -# text. Otherwise, the brief description is used as-is. If left blank, the -# following values are used ($name is automatically replaced with the name of -# the entity):The $name class, The $name widget, The $name file, is, provides, -# specifies, contains, represents, a, an and the. - -ABBREVIATE_BRIEF = "The $name class" \ - "The $name widget" \ - "The $name file" \ - is \ - provides \ - specifies \ - contains \ - represents \ - a \ - an \ - the - -# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then -# doxygen will generate a detailed section even if there is only a brief -# description. -# The default value is: NO. - -ALWAYS_DETAILED_SEC = NO - -# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all -# inherited members of a class in the documentation of that class as if those -# members were ordinary class members. Constructors, destructors and assignment -# operators of the base classes will not be shown. -# The default value is: NO. - -INLINE_INHERITED_MEMB = NO - -# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path -# before files name in the file list and in the header files. If set to NO the -# shortest path that makes the file name unique will be used -# The default value is: YES. - -FULL_PATH_NAMES = YES - -# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. -# Stripping is only done if one of the specified strings matches the left-hand -# part of the path. The tag can be used to show relative paths in the file list. -# If left blank the directory from which doxygen is run is used as the path to -# strip. -# -# Note that you can specify absolute paths here, but also relative paths, which -# will be relative from the directory where doxygen is started. -# This tag requires that the tag FULL_PATH_NAMES is set to YES. 
- -STRIP_FROM_PATH = - -# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the -# path mentioned in the documentation of a class, which tells the reader which -# header file to include in order to use a class. If left blank only the name of -# the header file containing the class definition is used. Otherwise one should -# specify the list of include paths that are normally passed to the compiler -# using the -I flag. - -STRIP_FROM_INC_PATH = - -# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but -# less readable) file names. This can be useful is your file systems doesn't -# support long names like on DOS, Mac, or CD-ROM. -# The default value is: NO. - -SHORT_NAMES = NO - -# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the -# first line (until the first dot) of a Javadoc-style comment as the brief -# description. If set to NO, the Javadoc-style will behave just like regular Qt- -# style comments (thus requiring an explicit @brief command for a brief -# description.) -# The default value is: NO. - -JAVADOC_AUTOBRIEF = NO - -# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first -# line (until the first dot) of a Qt-style comment as the brief description. If -# set to NO, the Qt-style will behave just like regular Qt-style comments (thus -# requiring an explicit \brief command for a brief description.) -# The default value is: NO. - -QT_AUTOBRIEF = NO - -# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a -# multi-line C++ special comment block (i.e. a block of //! or /// comments) as -# a brief description. This used to be the default behavior. The new default is -# to treat a multi-line C++ comment block as a detailed description. Set this -# tag to YES if you prefer the old behavior instead. -# -# Note that setting this tag to YES also means that rational rose comments are -# not recognized any more. -# The default value is: NO. - -MULTILINE_CPP_IS_BRIEF = NO - -# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the -# documentation from any documented member that it re-implements. -# The default value is: YES. - -INHERIT_DOCS = YES - -# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new -# page for each member. If set to NO, the documentation of a member will be part -# of the file/class/namespace that contains it. -# The default value is: NO. - -SEPARATE_MEMBER_PAGES = NO - -# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen -# uses this value to replace tabs by spaces in code fragments. -# Minimum value: 1, maximum value: 16, default value: 4. - -TAB_SIZE = 2 - -# This tag can be used to specify a number of aliases that act as commands in -# the documentation. An alias has the form: -# name=value -# For example adding -# "sideeffect=@par Side Effects:\n" -# will allow you to put the command \sideeffect (or @sideeffect) in the -# documentation, which will result in a user-defined paragraph with heading -# "Side Effects:". You can put \n's in the value part of an alias to insert -# newlines. - -ALIASES = - -# This tag can be used to specify a number of word-keyword mappings (TCL only). -# A mapping has the form "name=value". For example adding "class=itcl::class" -# will allow you to use the command class in the itcl::class meaning. - -TCL_SUBST = - -# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources -# only. 
Doxygen will then generate output that is more tailored for C. For -# instance, some of the names that are used will be different. The list of all -# members will be omitted, etc. -# The default value is: NO. - -OPTIMIZE_OUTPUT_FOR_C = NO - -# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or -# Python sources only. Doxygen will then generate output that is more tailored -# for that language. For instance, namespaces will be presented as packages, -# qualified scopes will look different, etc. -# The default value is: NO. - -OPTIMIZE_OUTPUT_JAVA = NO - -# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran -# sources. Doxygen will then generate output that is tailored for Fortran. -# The default value is: NO. - -OPTIMIZE_FOR_FORTRAN = NO - -# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL -# sources. Doxygen will then generate output that is tailored for VHDL. -# The default value is: NO. - -OPTIMIZE_OUTPUT_VHDL = NO - -# Doxygen selects the parser to use depending on the extension of the files it -# parses. With this tag you can assign which parser to use for a given -# extension. Doxygen has a built-in mapping, but you can override or extend it -# using this tag. The format is ext=language, where ext is a file extension, and -# language is one of the parsers supported by doxygen: IDL, Java, Javascript, -# C#, C, C++, D, PHP, Objective-C, Python, Fortran (fixed format Fortran: -# FortranFixed, free formatted Fortran: FortranFree, unknown formatted Fortran: -# Fortran. In the later case the parser tries to guess whether the code is fixed -# or free formatted code, this is the default for Fortran type files), VHDL. For -# instance to make doxygen treat .inc files as Fortran files (default is PHP), -# and .f files as C (default is Fortran), use: inc=Fortran f=C. -# -# Note: For files without extension you can use no_extension as a placeholder. -# -# Note that for custom extensions you also need to set FILE_PATTERNS otherwise -# the files are not read by doxygen. - -EXTENSION_MAPPING = - -# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments -# according to the Markdown format, which allows for more readable -# documentation. See http://daringfireball.net/projects/markdown/ for details. -# The output of markdown processing is further processed by doxygen, so you can -# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in -# case of backward compatibilities issues. -# The default value is: YES. - -MARKDOWN_SUPPORT = YES - -# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up -# to that level are automatically included in the table of contents, even if -# they do not have an id attribute. -# Note: This feature currently applies only to Markdown headings. -# Minimum value: 0, maximum value: 99, default value: 0. -# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. - -TOC_INCLUDE_HEADINGS = 0 - -# When enabled doxygen tries to link words that correspond to documented -# classes, or namespaces to their corresponding documentation. Such a link can -# be prevented in individual cases by putting a % sign in front of the word or -# globally by setting AUTOLINK_SUPPORT to NO. -# The default value is: YES. - -AUTOLINK_SUPPORT = YES - -# If you use STL classes (i.e. std::string, std::vector, etc.) 
but do not want -# to include (a tag file for) the STL sources as input, then you should set this -# tag to YES in order to let doxygen match functions declarations and -# definitions whose arguments contain STL classes (e.g. func(std::string); -# versus func(std::string) {}). This also make the inheritance and collaboration -# diagrams that involve STL classes more complete and accurate. -# The default value is: NO. - -BUILTIN_STL_SUPPORT = NO - -# If you use Microsoft's C++/CLI language, you should set this option to YES to -# enable parsing support. -# The default value is: NO. - -CPP_CLI_SUPPORT = NO - -# Set the SIP_SUPPORT tag to YES if your project consists of sip (see: -# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen -# will parse them like normal C++ but will assume all classes use public instead -# of private inheritance when no explicit protection keyword is present. -# The default value is: NO. - -SIP_SUPPORT = NO - -# For Microsoft's IDL there are propget and propput attributes to indicate -# getter and setter methods for a property. Setting this option to YES will make -# doxygen to replace the get and set methods by a property in the documentation. -# This will only work if the methods are indeed getting or setting a simple -# type. If this is not the case, or you want to show the methods anyway, you -# should set this option to NO. -# The default value is: YES. - -IDL_PROPERTY_SUPPORT = YES - -# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC -# tag is set to YES then doxygen will reuse the documentation of the first -# member in the group (if any) for the other members of the group. By default -# all members of a group must be documented explicitly. -# The default value is: NO. - -DISTRIBUTE_GROUP_DOC = NO - -# If one adds a struct or class to a group and this option is enabled, then also -# any nested class or struct is added to the same group. By default this option -# is disabled and one has to add nested compounds explicitly via \ingroup. -# The default value is: NO. - -GROUP_NESTED_COMPOUNDS = NO - -# Set the SUBGROUPING tag to YES to allow class member groups of the same type -# (for instance a group of public functions) to be put as a subgroup of that -# type (e.g. under the Public Functions section). Set it to NO to prevent -# subgrouping. Alternatively, this can be done per class using the -# \nosubgrouping command. -# The default value is: YES. - -SUBGROUPING = YES - -# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions -# are shown inside the group in which they are included (e.g. using \ingroup) -# instead of on a separate page (for HTML and Man pages) or section (for LaTeX -# and RTF). -# -# Note that this feature does not work in combination with -# SEPARATE_MEMBER_PAGES. -# The default value is: NO. - -INLINE_GROUPED_CLASSES = NO - -# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions -# with only public data fields or simple typedef fields will be shown inline in -# the documentation of the scope in which they are defined (i.e. file, -# namespace, or group documentation), provided this scope is documented. If set -# to NO, structs, classes, and unions are shown on a separate page (for HTML and -# Man pages) or section (for LaTeX and RTF). -# The default value is: NO. 
- -INLINE_SIMPLE_STRUCTS = NO - -# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or -# enum is documented as struct, union, or enum with the name of the typedef. So -# typedef struct TypeS {} TypeT, will appear in the documentation as a struct -# with name TypeT. When disabled the typedef will appear as a member of a file, -# namespace, or class. And the struct will be named TypeS. This can typically be -# useful for C code in case the coding convention dictates that all compound -# types are typedef'ed and only the typedef is referenced, never the tag name. -# The default value is: NO. - -TYPEDEF_HIDES_STRUCT = NO - -# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This -# cache is used to resolve symbols given their name and scope. Since this can be -# an expensive process and often the same symbol appears multiple times in the -# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small -# doxygen will become slower. If the cache is too large, memory is wasted. The -# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range -# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 -# symbols. At the end of a run doxygen will report the cache usage and suggest -# the optimal cache size from a speed point of view. -# Minimum value: 0, maximum value: 9, default value: 0. - -LOOKUP_CACHE_SIZE = 0 - -#--------------------------------------------------------------------------- -# Build related configuration options -#--------------------------------------------------------------------------- - -# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in -# documentation are documented, even if no documentation was available. Private -# class members and static file members will be hidden unless the -# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. -# Note: This will also disable the warnings about undocumented members that are -# normally produced when WARNINGS is set to YES. -# The default value is: NO. - -EXTRACT_ALL = YES - -# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will -# be included in the documentation. -# The default value is: NO. - -EXTRACT_PRIVATE = NO - -# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal -# scope will be included in the documentation. -# The default value is: NO. - -EXTRACT_PACKAGE = NO - -# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be -# included in the documentation. -# The default value is: NO. - -EXTRACT_STATIC = NO - -# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined -# locally in source files will be included in the documentation. If set to NO, -# only classes defined in header files are included. Does not have any effect -# for Java sources. -# The default value is: YES. - -EXTRACT_LOCAL_CLASSES = YES - -# This flag is only useful for Objective-C code. If set to YES, local methods, -# which are defined in the implementation section but not in the interface are -# included in the documentation. If set to NO, only methods in the interface are -# included. -# The default value is: NO. - -EXTRACT_LOCAL_METHODS = NO - -# If this flag is set to YES, the members of anonymous namespaces will be -# extracted and appear in the documentation as a namespace called -# 'anonymous_namespace{file}', where file will be replaced with the base name of -# the file that contains the anonymous namespace. 
By default anonymous namespace -# are hidden. -# The default value is: NO. - -EXTRACT_ANON_NSPACES = NO - -# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all -# undocumented members inside documented classes or files. If set to NO these -# members will be included in the various overviews, but no documentation -# section is generated. This option has no effect if EXTRACT_ALL is enabled. -# The default value is: NO. - -HIDE_UNDOC_MEMBERS = NO - -# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all -# undocumented classes that are normally visible in the class hierarchy. If set -# to NO, these classes will be included in the various overviews. This option -# has no effect if EXTRACT_ALL is enabled. -# The default value is: NO. - -HIDE_UNDOC_CLASSES = NO - -# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend -# (class|struct|union) declarations. If set to NO, these declarations will be -# included in the documentation. -# The default value is: NO. - -HIDE_FRIEND_COMPOUNDS = NO - -# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any -# documentation blocks found inside the body of a function. If set to NO, these -# blocks will be appended to the function's detailed documentation block. -# The default value is: NO. - -HIDE_IN_BODY_DOCS = NO - -# The INTERNAL_DOCS tag determines if documentation that is typed after a -# \internal command is included. If the tag is set to NO then the documentation -# will be excluded. Set it to YES to include the internal documentation. -# The default value is: NO. - -INTERNAL_DOCS = NO - -# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file -# names in lower-case letters. If set to YES, upper-case letters are also -# allowed. This is useful if you have classes or files whose names only differ -# in case and if your file system supports case sensitive file names. Windows -# and Mac users are advised to set this option to NO. -# The default value is: system dependent. - -CASE_SENSE_NAMES = NO - -# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with -# their full class and namespace scopes in the documentation. If set to YES, the -# scope will be hidden. -# The default value is: NO. - -HIDE_SCOPE_NAMES = NO - -# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will -# append additional text to a page's title, such as Class Reference. If set to -# YES the compound reference will be hidden. -# The default value is: NO. - -HIDE_COMPOUND_REFERENCE= NO - -# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of -# the files that are included by a file in the documentation of that file. -# The default value is: YES. - -SHOW_INCLUDE_FILES = YES - -# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each -# grouped member an include statement to the documentation, telling the reader -# which file to include in order to use the member. -# The default value is: NO. - -SHOW_GROUPED_MEMB_INC = NO - -# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include -# files with double quotes in the documentation rather than with sharp brackets. -# The default value is: NO. - -FORCE_LOCAL_INCLUDES = NO - -# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the -# documentation for inline members. -# The default value is: YES. 
- -INLINE_INFO = YES - -# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the -# (detailed) documentation of file and class members alphabetically by member -# name. If set to NO, the members will appear in declaration order. -# The default value is: YES. - -SORT_MEMBER_DOCS = YES - -# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief -# descriptions of file, namespace and class members alphabetically by member -# name. If set to NO, the members will appear in declaration order. Note that -# this will also influence the order of the classes in the class list. -# The default value is: NO. - -SORT_BRIEF_DOCS = NO - -# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the -# (brief and detailed) documentation of class members so that constructors and -# destructors are listed first. If set to NO the constructors will appear in the -# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. -# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief -# member documentation. -# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting -# detailed member documentation. -# The default value is: NO. - -SORT_MEMBERS_CTORS_1ST = NO - -# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy -# of group names into alphabetical order. If set to NO the group names will -# appear in their defined order. -# The default value is: NO. - -SORT_GROUP_NAMES = NO - -# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by -# fully-qualified names, including namespaces. If set to NO, the class list will -# be sorted only by class name, not including the namespace part. -# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. -# Note: This option applies only to the class list, not to the alphabetical -# list. -# The default value is: NO. - -SORT_BY_SCOPE_NAME = NO - -# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper -# type resolution of all parameters of a function it will reject a match between -# the prototype and the implementation of a member function even if there is -# only one candidate or it is obvious which candidate to choose by doing a -# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still -# accept a match between prototype and implementation in such cases. -# The default value is: NO. - -STRICT_PROTO_MATCHING = NO - -# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo -# list. This list is created by putting \todo commands in the documentation. -# The default value is: YES. - -GENERATE_TODOLIST = YES - -# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test -# list. This list is created by putting \test commands in the documentation. -# The default value is: YES. - -GENERATE_TESTLIST = YES - -# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug -# list. This list is created by putting \bug commands in the documentation. -# The default value is: YES. - -GENERATE_BUGLIST = YES - -# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO) -# the deprecated list. This list is created by putting \deprecated commands in -# the documentation. -# The default value is: YES. - -GENERATE_DEPRECATEDLIST= YES - -# The ENABLED_SECTIONS tag can be used to enable conditional documentation -# sections, marked by \if ... \endif and \cond -# ... \endcond blocks. 
- -ENABLED_SECTIONS = - -# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the -# initial value of a variable or macro / define can have for it to appear in the -# documentation. If the initializer consists of more lines than specified here -# it will be hidden. Use a value of 0 to hide initializers completely. The -# appearance of the value of individual variables and macros / defines can be -# controlled using \showinitializer or \hideinitializer command in the -# documentation regardless of this setting. -# Minimum value: 0, maximum value: 10000, default value: 30. - -MAX_INITIALIZER_LINES = 30 - -# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at -# the bottom of the documentation of classes and structs. If set to YES, the -# list will mention the files that were used to generate the documentation. -# The default value is: YES. - -SHOW_USED_FILES = YES - -# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This -# will remove the Files entry from the Quick Index and from the Folder Tree View -# (if specified). -# The default value is: YES. - -SHOW_FILES = YES - -# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces -# page. This will remove the Namespaces entry from the Quick Index and from the -# Folder Tree View (if specified). -# The default value is: YES. - -SHOW_NAMESPACES = YES - -# The FILE_VERSION_FILTER tag can be used to specify a program or script that -# doxygen should invoke to get the current version for each file (typically from -# the version control system). Doxygen will invoke the program by executing (via -# popen()) the command command input-file, where command is the value of the -# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided -# by doxygen. Whatever the program writes to standard output is used as the file -# version. For an example see the documentation. - -FILE_VERSION_FILTER = - -# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed -# by doxygen. The layout file controls the global structure of the generated -# output files in an output format independent way. To create the layout file -# that represents doxygen's defaults, run doxygen with the -l option. You can -# optionally specify a file name after the option, if omitted DoxygenLayout.xml -# will be used as the name of the layout file. -# -# Note that if you run doxygen from a directory containing a file called -# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE -# tag is left empty. - -LAYOUT_FILE = - -# The CITE_BIB_FILES tag can be used to specify one or more bib files containing -# the reference definitions. This must be a list of .bib files. The .bib -# extension is automatically appended if omitted. This requires the bibtex tool -# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info. -# For LaTeX the style of the bibliography can be controlled using -# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the -# search path. See also \cite for info how to create references. - -CITE_BIB_FILES = - -#--------------------------------------------------------------------------- -# Configuration options related to warning and progress messages -#--------------------------------------------------------------------------- - -# The QUIET tag can be used to turn on/off the messages that are generated to -# standard output by doxygen. 
If QUIET is set to YES this implies that the -# messages are off. -# The default value is: NO. - -QUIET = NO - -# The WARNINGS tag can be used to turn on/off the warning messages that are -# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES -# this implies that the warnings are on. -# -# Tip: Turn warnings on while writing the documentation. -# The default value is: YES. - -WARNINGS = YES - -# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate -# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag -# will automatically be disabled. -# The default value is: YES. - -WARN_IF_UNDOCUMENTED = YES - -# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for -# potential errors in the documentation, such as not documenting some parameters -# in a documented function, or documenting parameters that don't exist or using -# markup commands wrongly. -# The default value is: YES. - -WARN_IF_DOC_ERROR = YES - -# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that -# are documented, but have no documentation for their parameters or return -# value. If set to NO, doxygen will only warn about wrong or incomplete -# parameter documentation, but not about the absence of documentation. -# The default value is: NO. - -WARN_NO_PARAMDOC = NO - -# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when -# a warning is encountered. -# The default value is: NO. - -WARN_AS_ERROR = NO - -# The WARN_FORMAT tag determines the format of the warning messages that doxygen -# can produce. The string should contain the $file, $line, and $text tags, which -# will be replaced by the file and line number from which the warning originated -# and the warning text. Optionally the format may contain $version, which will -# be replaced by the version of the file (if it could be obtained via -# FILE_VERSION_FILTER) -# The default value is: $file:$line: $text. - -WARN_FORMAT = "$file:$line: $text" - -# The WARN_LOGFILE tag can be used to specify a file to which warning and error -# messages should be written. If left blank the output is written to standard -# error (stderr). - -WARN_LOGFILE = - -#--------------------------------------------------------------------------- -# Configuration options related to the input files -#--------------------------------------------------------------------------- - -# The INPUT tag is used to specify the files and/or directories that contain -# documented source files. You may enter file names like myfile.cpp or -# directories like /usr/src/myproject. Separate the files or directories with -# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING -# Note: If this tag is empty the current directory is searched. - -INPUT = ../src - -# This tag can be used to specify the character encoding of the source files -# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses -# libiconv (or the iconv built into libc) for the transcoding. See the libiconv -# documentation (see: http://www.gnu.org/software/libiconv) for the list of -# possible encodings. -# The default value is: UTF-8. - -INPUT_ENCODING = UTF-8 - -# If the value of the INPUT tag contains directories, you can use the -# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and -# *.h) to filter out the source-files in the directories. 
-# -# Note that for custom extensions or not directly supported extensions you also -# need to set EXTENSION_MAPPING for the extension otherwise the files are not -# read by doxygen. -# -# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, -# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, -# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, -# *.m, *.markdown, *.md, *.mm, *.dox, *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, -# *.f, *.for, *.tcl, *.vhd, *.vhdl, *.ucf and *.qsf. - -FILE_PATTERNS = *.c \ - *.cc \ - *.cxx \ - *.cpp \ - *.c++ \ - *.java \ - *.ii \ - *.ixx \ - *.ipp \ - *.i++ \ - *.inl \ - *.idl \ - *.ddl \ - *.odl \ - *.h \ - *.hh \ - *.hxx \ - *.hpp \ - *.h++ \ - *.cs \ - *.d \ - *.php \ - *.php4 \ - *.php5 \ - *.phtml \ - *.inc \ - *.m \ - *.markdown \ - *.md \ - *.mm \ - *.dox \ - *.py \ - *.pyw \ - *.f90 \ - *.f95 \ - *.f03 \ - *.f08 \ - *.f \ - *.for \ - *.tcl \ - *.vhd \ - *.vhdl \ - *.ucf \ - *.qsf - -# The RECURSIVE tag can be used to specify whether or not subdirectories should -# be searched for input files as well. -# The default value is: NO. - -RECURSIVE = NO - -# The EXCLUDE tag can be used to specify files and/or directories that should be -# excluded from the INPUT source files. This way you can easily exclude a -# subdirectory from a directory tree whose root is specified with the INPUT tag. -# -# Note that relative paths are relative to the directory from which doxygen is -# run. - -EXCLUDE = ../src/utarray.h ../src/uthash.h - -# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or -# directories that are symbolic links (a Unix file system feature) are excluded -# from the input. -# The default value is: NO. - -EXCLUDE_SYMLINKS = NO - -# If the value of the INPUT tag contains directories, you can use the -# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude -# certain files from those directories. -# -# Note that the wildcards are matched against the file with absolute path, so to -# exclude all test directories for example use the pattern */test/* - -EXCLUDE_PATTERNS = - -# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names -# (namespaces, classes, functions, etc.) that should be excluded from the -# output. The symbol name can be a fully qualified name, a word, or if the -# wildcard * is used, a substring. Examples: ANamespace, AClass, -# AClass::ANamespace, ANamespace::*Test -# -# Note that the wildcards are matched against the file with absolute path, so to -# exclude all test directories use the pattern */test/* - -EXCLUDE_SYMBOLS = - -# The EXAMPLE_PATH tag can be used to specify one or more files or directories -# that contain example code fragments that are included (see the \include -# command). - -EXAMPLE_PATH = - -# If the value of the EXAMPLE_PATH tag contains directories, you can use the -# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and -# *.h) to filter out the source-files in the directories. If left blank all -# files are included. - -EXAMPLE_PATTERNS = * - -# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be -# searched for input files to be used with the \include or \dontinclude commands -# irrespective of the value of the RECURSIVE tag. -# The default value is: NO. 
- -EXAMPLE_RECURSIVE = NO - -# The IMAGE_PATH tag can be used to specify one or more files or directories -# that contain images that are to be included in the documentation (see the -# \image command). - -IMAGE_PATH = - -# The INPUT_FILTER tag can be used to specify a program that doxygen should -# invoke to filter for each input file. Doxygen will invoke the filter program -# by executing (via popen()) the command: -# -# -# -# where is the value of the INPUT_FILTER tag, and is the -# name of an input file. Doxygen will then use the output that the filter -# program writes to standard output. If FILTER_PATTERNS is specified, this tag -# will be ignored. -# -# Note that the filter must not add or remove lines; it is applied before the -# code is scanned, but not when the output code is generated. If lines are added -# or removed, the anchors will not be placed correctly. -# -# Note that for custom extensions or not directly supported extensions you also -# need to set EXTENSION_MAPPING for the extension otherwise the files are not -# properly processed by doxygen. - -INPUT_FILTER = - -# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern -# basis. Doxygen will compare the file name with each pattern and apply the -# filter if there is a match. The filters are a list of the form: pattern=filter -# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how -# filters are used. If the FILTER_PATTERNS tag is empty or if none of the -# patterns match the file name, INPUT_FILTER is applied. -# -# Note that for custom extensions or not directly supported extensions you also -# need to set EXTENSION_MAPPING for the extension otherwise the files are not -# properly processed by doxygen. - -FILTER_PATTERNS = - -# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using -# INPUT_FILTER) will also be used to filter the input files that are used for -# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES). -# The default value is: NO. - -FILTER_SOURCE_FILES = NO - -# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file -# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and -# it is also possible to disable source filtering for a specific pattern using -# *.ext= (so without naming a filter). -# This tag requires that the tag FILTER_SOURCE_FILES is set to YES. - -FILTER_SOURCE_PATTERNS = - -# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that -# is part of the input, its contents will be placed on the main page -# (index.html). This can be useful if you have a project on for instance GitHub -# and want to reuse the introduction page also for the doxygen output. - -USE_MDFILE_AS_MAINPAGE = - -#--------------------------------------------------------------------------- -# Configuration options related to source browsing -#--------------------------------------------------------------------------- - -# If the SOURCE_BROWSER tag is set to YES then a list of source files will be -# generated. Documented entities will be cross-referenced with these sources. -# -# Note: To get rid of all source code in the generated output, make sure that -# also VERBATIM_HEADERS is set to NO. -# The default value is: NO. - -SOURCE_BROWSER = NO - -# Setting the INLINE_SOURCES tag to YES will include the body of functions, -# classes and enums directly into the documentation. -# The default value is: NO. 
- -INLINE_SOURCES = NO - -# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any -# special comment blocks from generated source code fragments. Normal C, C++ and -# Fortran comments will always remain visible. -# The default value is: YES. - -STRIP_CODE_COMMENTS = YES - -# If the REFERENCED_BY_RELATION tag is set to YES then for each documented -# function all documented functions referencing it will be listed. -# The default value is: NO. - -REFERENCED_BY_RELATION = NO - -# If the REFERENCES_RELATION tag is set to YES then for each documented function -# all documented entities called/used by that function will be listed. -# The default value is: NO. - -REFERENCES_RELATION = NO - -# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set -# to YES then the hyperlinks from functions in REFERENCES_RELATION and -# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will -# link to the documentation. -# The default value is: YES. - -REFERENCES_LINK_SOURCE = YES - -# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the -# source code will show a tooltip with additional information such as prototype, -# brief description and links to the definition and documentation. Since this -# will make the HTML file larger and loading of large files a bit slower, you -# can opt to disable this feature. -# The default value is: YES. -# This tag requires that the tag SOURCE_BROWSER is set to YES. - -SOURCE_TOOLTIPS = YES - -# If the USE_HTAGS tag is set to YES then the references to source code will -# point to the HTML generated by the htags(1) tool instead of doxygen built-in -# source browser. The htags tool is part of GNU's global source tagging system -# (see http://www.gnu.org/software/global/global.html). You will need version -# 4.8.6 or higher. -# -# To use it do the following: -# - Install the latest version of global -# - Enable SOURCE_BROWSER and USE_HTAGS in the config file -# - Make sure the INPUT points to the root of the source tree -# - Run doxygen as normal -# -# Doxygen will invoke htags (and that will in turn invoke gtags), so these -# tools must be available from the command line (i.e. in the search path). -# -# The result: instead of the source browser generated by doxygen, the links to -# source code will now point to the output of htags. -# The default value is: NO. -# This tag requires that the tag SOURCE_BROWSER is set to YES. - -USE_HTAGS = NO - -# If the VERBATIM_HEADERS tag is set the YES then doxygen will generate a -# verbatim copy of the header file for each class for which an include is -# specified. Set to NO to disable this. -# See also: Section \class. -# The default value is: YES. - -VERBATIM_HEADERS = YES - -#--------------------------------------------------------------------------- -# Configuration options related to the alphabetical class index -#--------------------------------------------------------------------------- - -# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all -# compounds will be generated. Enable this if the project contains a lot of -# classes, structs, unions or interfaces. -# The default value is: YES. - -ALPHABETICAL_INDEX = YES - -# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in -# which the alphabetical index list will be split. -# Minimum value: 1, maximum value: 20, default value: 5. -# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. 
- -COLS_IN_ALPHA_INDEX = 5 - -# In case all classes in a project start with a common prefix, all classes will -# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag -# can be used to specify a prefix (or a list of prefixes) that should be ignored -# while generating the index headers. -# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. - -IGNORE_PREFIX = - -#--------------------------------------------------------------------------- -# Configuration options related to the HTML output -#--------------------------------------------------------------------------- - -# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output -# The default value is: YES. - -GENERATE_HTML = YES - -# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a -# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of -# it. -# The default directory is: html. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_OUTPUT = html - -# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each -# generated HTML page (for example: .htm, .php, .asp). -# The default value is: .html. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_FILE_EXTENSION = .html - -# The HTML_HEADER tag can be used to specify a user-defined HTML header file for -# each generated HTML page. If the tag is left blank doxygen will generate a -# standard header. -# -# To get valid HTML the header file that includes any scripts and style sheets -# that doxygen needs, which is dependent on the configuration options used (e.g. -# the setting GENERATE_TREEVIEW). It is highly recommended to start with a -# default header using -# doxygen -w html new_header.html new_footer.html new_stylesheet.css -# YourConfigFile -# and then modify the file new_header.html. See also section "Doxygen usage" -# for information on how to generate the default header that doxygen normally -# uses. -# Note: The header is subject to change so you typically have to regenerate the -# default header when upgrading to a newer version of doxygen. For a description -# of the possible markers and block names see the documentation. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_HEADER = - -# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each -# generated HTML page. If the tag is left blank doxygen will generate a standard -# footer. See HTML_HEADER for more information on how to generate a default -# footer and what special commands can be used inside the footer. See also -# section "Doxygen usage" for information on how to generate the default footer -# that doxygen normally uses. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_FOOTER = - -# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style -# sheet that is used by each HTML page. It can be used to fine-tune the look of -# the HTML output. If left blank doxygen will generate a default style sheet. -# See also section "Doxygen usage" for information on how to generate the style -# sheet that doxygen normally uses. -# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as -# it is more robust and this tag (HTML_STYLESHEET) will in the future become -# obsolete. -# This tag requires that the tag GENERATE_HTML is set to YES. 
- -HTML_STYLESHEET = - -# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined -# cascading style sheets that are included after the standard style sheets -# created by doxygen. Using this option one can overrule certain style aspects. -# This is preferred over using HTML_STYLESHEET since it does not replace the -# standard style sheet and is therefore more robust against future updates. -# Doxygen will copy the style sheet files to the output directory. -# Note: The order of the extra style sheet files is of importance (e.g. the last -# style sheet in the list overrules the setting of the previous ones in the -# list). For an example see the documentation. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_EXTRA_STYLESHEET = - -# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or -# other source files which should be copied to the HTML output directory. Note -# that these files will be copied to the base HTML output directory. Use the -# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these -# files. In the HTML_STYLESHEET file, use the file name only. Also note that the -# files will be copied as-is; there are no commands or markers available. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_EXTRA_FILES = - -# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen -# will adjust the colors in the style sheet and background images according to -# this color. Hue is specified as an angle on a colorwheel, see -# http://en.wikipedia.org/wiki/Hue for more information. For instance the value -# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 -# purple, and 360 is red again. -# Minimum value: 0, maximum value: 359, default value: 220. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_COLORSTYLE_HUE = 220 - -# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors -# in the HTML output. For a value of 0 the output will use grayscales only. A -# value of 255 will produce the most vivid colors. -# Minimum value: 0, maximum value: 255, default value: 100. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_COLORSTYLE_SAT = 100 - -# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the -# luminance component of the colors in the HTML output. Values below 100 -# gradually make the output lighter, whereas values above 100 make the output -# darker. The value divided by 100 is the actual gamma applied, so 80 represents -# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not -# change the gamma. -# Minimum value: 40, maximum value: 240, default value: 80. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_COLORSTYLE_GAMMA = 80 - -# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML -# page will contain the date and time when the page was generated. Setting this -# to YES can help to show when doxygen was last run and thus if the -# documentation is up to date. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_TIMESTAMP = NO - -# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML -# documentation will contain sections that can be hidden and shown after the -# page has loaded. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. 
- -HTML_DYNAMIC_SECTIONS = NO - -# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries -# shown in the various tree structured indices initially; the user can expand -# and collapse entries dynamically later on. Doxygen will expand the tree to -# such a level that at most the specified number of entries are visible (unless -# a fully collapsed tree already exceeds this amount). So setting the number of -# entries 1 will produce a full collapsed tree by default. 0 is a special value -# representing an infinite number of entries and will result in a full expanded -# tree by default. -# Minimum value: 0, maximum value: 9999, default value: 100. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_INDEX_NUM_ENTRIES = 100 - -# If the GENERATE_DOCSET tag is set to YES, additional index files will be -# generated that can be used as input for Apple's Xcode 3 integrated development -# environment (see: http://developer.apple.com/tools/xcode/), introduced with -# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a -# Makefile in the HTML output directory. Running make will produce the docset in -# that directory and running make install will install the docset in -# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at -# startup. See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html -# for more information. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_DOCSET = NO - -# This tag determines the name of the docset feed. A documentation feed provides -# an umbrella under which multiple documentation sets from a single provider -# (such as a company or product suite) can be grouped. -# The default value is: Doxygen generated docs. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_FEEDNAME = "Doxygen generated docs" - -# This tag specifies a string that should uniquely identify the documentation -# set bundle. This should be a reverse domain-name style string, e.g. -# com.mycompany.MyDocSet. Doxygen will append .docset to the name. -# The default value is: org.doxygen.Project. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_BUNDLE_ID = org.doxygen.Project - -# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify -# the documentation publisher. This should be a reverse domain-name style -# string, e.g. com.mycompany.MyDocSet.documentation. -# The default value is: org.doxygen.Publisher. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_PUBLISHER_ID = org.doxygen.Publisher - -# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher. -# The default value is: Publisher. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_PUBLISHER_NAME = Publisher - -# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three -# additional HTML index files: index.hhp, index.hhc, and index.hhk. The -# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop -# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on -# Windows. -# -# The HTML Help Workshop contains a compiler that can convert all HTML output -# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML -# files are now used as the Windows 98 help format, and will replace the old -# Windows help format (.hlp) on all Windows platforms in the future. 
Compressed -# HTML files also contain an index, a table of contents, and you can search for -# words in the documentation. The HTML workshop also contains a viewer for -# compressed HTML files. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_HTMLHELP = NO - -# The CHM_FILE tag can be used to specify the file name of the resulting .chm -# file. You can add a path in front of the file if the result should not be -# written to the html output directory. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -CHM_FILE = - -# The HHC_LOCATION tag can be used to specify the location (absolute path -# including file name) of the HTML help compiler (hhc.exe). If non-empty, -# doxygen will try to run the HTML help compiler on the generated index.hhp. -# The file has to be specified with full path. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -HHC_LOCATION = - -# The GENERATE_CHI flag controls if a separate .chi index file is generated -# (YES) or that it should be included in the master .chm file (NO). -# The default value is: NO. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -GENERATE_CHI = NO - -# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc) -# and project file content. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -CHM_INDEX_ENCODING = - -# The BINARY_TOC flag controls whether a binary table of contents is generated -# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it -# enables the Previous and Next buttons. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -BINARY_TOC = NO - -# The TOC_EXPAND flag can be set to YES to add extra items for group members to -# the table of contents of the HTML help documentation and to the tree view. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -TOC_EXPAND = NO - -# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and -# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that -# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help -# (.qch) of the generated HTML documentation. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_QHP = NO - -# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify -# the file name of the resulting .qch file. The path specified is relative to -# the HTML output folder. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QCH_FILE = - -# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help -# Project output. For more information please see Qt Help Project / Namespace -# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace). -# The default value is: org.doxygen.Project. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_NAMESPACE = org.doxygen.Project - -# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt -# Help Project output. For more information please see Qt Help Project / Virtual -# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual- -# folders). -# The default value is: doc. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_VIRTUAL_FOLDER = doc - -# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom -# filter to add. 
For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_CUST_FILTER_NAME = - -# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the -# custom filter to add. For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_CUST_FILTER_ATTRS = - -# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this -# project's filter section matches. Qt Help Project / Filter Attributes (see: -# http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes). -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_SECT_FILTER_ATTRS = - -# The QHG_LOCATION tag can be used to specify the location of Qt's -# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the -# generated .qhp file. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHG_LOCATION = - -# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be -# generated, together with the HTML files, they form an Eclipse help plugin. To -# install this plugin and make it available under the help contents menu in -# Eclipse, the contents of the directory containing the HTML and XML files needs -# to be copied into the plugins directory of eclipse. The name of the directory -# within the plugins directory should be the same as the ECLIPSE_DOC_ID value. -# After copying Eclipse needs to be restarted before the help appears. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_ECLIPSEHELP = NO - -# A unique identifier for the Eclipse help plugin. When installing the plugin -# the directory name containing the HTML and XML files should also have this -# name. Each documentation set should have its own identifier. -# The default value is: org.doxygen.Project. -# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES. - -ECLIPSE_DOC_ID = org.doxygen.Project - -# If you want full control over the layout of the generated HTML pages it might -# be necessary to disable the index and replace it with your own. The -# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top -# of each HTML page. A value of NO enables the index and the value YES disables -# it. Since the tabs in the index contain the same information as the navigation -# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -DISABLE_INDEX = NO - -# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index -# structure should be generated to display hierarchical information. If the tag -# value is set to YES, a side panel will be generated containing a tree-like -# index structure (just like the one that is generated for HTML Help). For this -# to work a browser that supports JavaScript, DHTML, CSS and frames is required -# (i.e. any modern browser). Windows users are probably better off using the -# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can -# further fine-tune the look of the index. 
As an example, the default style -# sheet generated by doxygen has an example that shows how to put an image at -# the root of the tree instead of the PROJECT_NAME. Since the tree basically has -# the same information as the tab index, you could consider setting -# DISABLE_INDEX to YES when enabling this option. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_TREEVIEW = NO - -# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that -# doxygen will group on one line in the generated HTML documentation. -# -# Note that a value of 0 will completely suppress the enum values from appearing -# in the overview section. -# Minimum value: 0, maximum value: 20, default value: 4. -# This tag requires that the tag GENERATE_HTML is set to YES. - -ENUM_VALUES_PER_LINE = 4 - -# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used -# to set the initial width (in pixels) of the frame in which the tree is shown. -# Minimum value: 0, maximum value: 1500, default value: 250. -# This tag requires that the tag GENERATE_HTML is set to YES. - -TREEVIEW_WIDTH = 250 - -# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to -# external symbols imported via tag files in a separate window. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -EXT_LINKS_IN_WINDOW = NO - -# Use this tag to change the font size of LaTeX formulas included as images in -# the HTML documentation. When you change the font size after a successful -# doxygen run you need to manually remove any form_*.png images from the HTML -# output directory to force them to be regenerated. -# Minimum value: 8, maximum value: 50, default value: 10. -# This tag requires that the tag GENERATE_HTML is set to YES. - -FORMULA_FONTSIZE = 10 - -# Use the FORMULA_TRANPARENT tag to determine whether or not the images -# generated for formulas are transparent PNGs. Transparent PNGs are not -# supported properly for IE 6.0, but are supported on all modern browsers. -# -# Note that when changing this option you need to delete any form_*.png files in -# the HTML output directory before the changes have effect. -# The default value is: YES. -# This tag requires that the tag GENERATE_HTML is set to YES. - -FORMULA_TRANSPARENT = YES - -# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see -# http://www.mathjax.org) which uses client side Javascript for the rendering -# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX -# installed or if you want to formulas look prettier in the HTML output. When -# enabled you may also need to install MathJax separately and configure the path -# to it using the MATHJAX_RELPATH option. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -USE_MATHJAX = NO - -# When MathJax is enabled you can set the default output format to be used for -# the MathJax output. See the MathJax site (see: -# http://docs.mathjax.org/en/latest/output.html) for more details. -# Possible values are: HTML-CSS (which is slower, but has the best -# compatibility), NativeMML (i.e. MathML) and SVG. -# The default value is: HTML-CSS. -# This tag requires that the tag USE_MATHJAX is set to YES. - -MATHJAX_FORMAT = HTML-CSS - -# When MathJax is enabled you need to specify the location relative to the HTML -# output directory using the MATHJAX_RELPATH option. The destination directory -# should contain the MathJax.js script. 
For instance, if the mathjax directory -# is located at the same level as the HTML output directory, then -# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax -# Content Delivery Network so you can quickly see the result without installing -# MathJax. However, it is strongly recommended to install a local copy of -# MathJax from http://www.mathjax.org before deployment. -# The default value is: http://cdn.mathjax.org/mathjax/latest. -# This tag requires that the tag USE_MATHJAX is set to YES. - -MATHJAX_RELPATH = http://cdn.mathjax.org/mathjax/latest - -# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax -# extension names that should be enabled during MathJax rendering. For example -# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols -# This tag requires that the tag USE_MATHJAX is set to YES. - -MATHJAX_EXTENSIONS = - -# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces -# of code that will be used on startup of the MathJax code. See the MathJax site -# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an -# example see the documentation. -# This tag requires that the tag USE_MATHJAX is set to YES. - -MATHJAX_CODEFILE = - -# When the SEARCHENGINE tag is enabled doxygen will generate a search box for -# the HTML output. The underlying search engine uses javascript and DHTML and -# should work on any modern browser. Note that when using HTML help -# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET) -# there is already a search function so this one should typically be disabled. -# For large projects the javascript based search engine can be slow, then -# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to -# search using the keyboard; to jump to the search box use <access key> + S -# (what the <access key> is depends on the OS and browser, but it is typically -# <CTRL>, <ALT>/