diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..5c0f059e --- /dev/null +++ b/.clang-format @@ -0,0 +1,5 @@ +BasedOnStyle: Google +ColumnLimit: 90 +DerivePointerAlignment: false +IndentCaseLabels: false +PointerAlignment: Right diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 00000000..a9d4671b --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,73 @@ +# Disable the following checks with reasons in parenthesis: +# +# -bugprone-macro-parentheses (inconsistent style) +# -google-readability-todo (potentially too restrictive) +# -misc-non-private-member-variables-in-classes (potentially too restrictive) +# -misc-unused-parameters (can be cleaned up in batch and enabled) +# -modernize-avoid-c-arrays (too restrictive) +# -modernize-concat-nested-namespaces (inconsistent style) +# -modernize-pass-by-value (too restrictive) +# -modernize-return-braced-init-list (inconsistent style) +# -modernize-use-emplace (more subtle behavior) +# -modernize-use-nodiscard (too much noise) +# -modernize-use-trailing-return-type (inconsistent style) +# -modernize-use-override (TODO(mwtian): re-enable after fixing existing derived classes) +# -modernize-avoid-bind (incorrect conversion) +# -modernize-loop-convert (more subtle behavior) +# -modernize-replace-disallow-copy-and-assign-macro (inconsistent style) +# -modernize-make-unique (doesn't work with private constructor) +# -modernize-make-shared (doesn't work with private constructor) +# Other readability-* rules (potentially too noisy, inconsistent style) +# Other rules not mentioned here or below (not yet evaluated) +# +# TODO: enable google-* and readability-* families of checks. 
+Checks: > + abseil-*, + bugprone-*, + -bugprone-macro-parentheses, + google-*, + -google-readability-todo, + misc-*, + -misc-non-private-member-variables-in-classes, + -misc-unused-parameters, + modernize-*, + -modernize-avoid-c-arrays, + -modernize-concat-nested-namespaces, + -modernize-pass-by-value, + -modernize-return-braced-init-list, + -modernize-use-emplace, + -modernize-use-nodiscard, + -modernize-use-trailing-return-type, + -modernize-avoid-bind, + -modernize-loop-convert, + -modernize-replace-disallow-copy-and-assign-macro, + -modernize-make-unique, + -modernize-make-shared, + -modernize-use-override, + performance-*, + readability-avoid-const-params-in-decls, + readability-braces-around-statements, + readability-const-return-type, + readability-container-size-empty, + readability-delete-null-pointer, + readability-else-after-return, + readability-implicit-bool-conversion, + readability-make-member-function-const, + readability-misleading-indentation, + readability-misplaced-array-index, + readability-named-parameter, + readability-non-const-parameter, + readability-redundant-*, + readability-static-definition-in-anonymous-namespace, + readability-string-compare, + readability-suspicious-call-argument, + +CheckOptions: + # Reduce noisiness of the bugprone-narrowing-conversions check. + - key: bugprone-narrowing-conversions.IgnoreConversionFromTypes + value: 'size_t;ptrdiff_t;size_type;difference_type' + - key: bugprone-narrowing-conversions.WarnOnEquivalentBitWidth + value: 'false' + +# Turn all the warnings from the checks above into errors. 
+WarningsAsErrors: "*" diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..1142bbd4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,198 @@ +# Bazel +bazel-bin/ +bazel-out/ +bazel-streaming/ +bazel-test/ + +# Redis temporary files +*dump.rdb + +# Python byte code files +*.pyc +python/.eggs + +# Backup files +*.bak + +# Emacs temporary files +*~ +*# + +# Compiled Object files +*.slo +*.lo +*.o +*.xo +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll +python/ray/_raylet.pyd + +# Incremental linking files +*.ilk + +# Library export files +*.exp + +# Debug symbols +*.pdb + +# Fortran module files +*.mod +!deploy/ray-operator/go.mod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +# Visual Studio files +/packages +*.suo +*.user +*.VC.db +*.VC.opendb + +# Protobuf-generated files +*_pb2.py +*.pb.h +*.pb.cc + +# Ray cluster configuration +scripts/nodes.txt + +# OS X folder attributes +.DS_Store + +# Debug files +*.dSYM/ +*.su + +# Python setup files +*.egg-info + +# Compressed files +*.gz + +# Datasets from examples +**/MNIST_data/ +**/cifar-10-batches-bin/ + +# Generated documentation files +/doc/_build +/doc/source/_static/thumbs +/doc/source/tune/generated_guides/ + +# User-specific stuff: +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/dictionaries +.llvm-local.bazelrc + +# Sensitive or high-churn files: +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.xml +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml + +# Gradle: +.idea/**/gradle.xml +.idea/**/libraries +.idea + +# Website +/site/Gemfile.lock +/site/.sass-cache +/site/_site + +# Pytest Cache +**/.pytest_cache +**/.cache +.benchmarks +python-driver-* + +# Vscode +.vscode/ + +*.iml + +# Java +java/**/target +java/**/lib +java/**/.settings +java/**/.classpath +java/**/.project +java/runtime/native_dependencies/ + 
+dependency-reduced-pom.xml + +# Cpp +cpp/example/thirdparty/ + +# streaming/python +streaming/python/generated/ +streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/generated/ +streaming/build/java +.clwb +streaming/**/.settings +streaming/java/**/target +streaming/java/**/.classpath +streaming/java/**/.project +streaming/java/**/*.log + +# pom.xml files generated from pom_template.xml +java/**/pom.xml +streaming/java/**/pom.xml + +# python virtual env +venv + +# pyenv version file +.python-version + +# Vim +.*.swp +*.swp +tags +tags.lock +tags.temp + +# Emacs +.#* + +# tools +tools/prometheus* + +# ray project files +project-id +.mypy_cache/ + +# gitpod cache related +.pip-cache/ +.bazel-cache/ + +# release test related +.anyscale.yaml +test_state.json + +# workflow storage +workflow_data/ + +# vscode java extention generated +.factorypath diff --git a/README.md b/README.md index 0d3216fe..b6492d9a 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -# Mobius online learning. +# Mobius : online machine learning. -## Code Editor -Recommend Atom (https://atom.io/) as the code editor. -Recommended plugins: atom-beautify, python-indent, auto-indent, vim-mode-plus (for VIM-ers) +## Streaming + +## Training diff --git a/streaming/.bazelrc b/streaming/.bazelrc new file mode 100644 index 00000000..6c70b2e9 --- /dev/null +++ b/streaming/.bazelrc @@ -0,0 +1,192 @@ +# Must be first. Enables build:windows, build:linux, build:macos, build:freebsd, build:openbsd +build --enable_platform_specific_config +############################################################################### +# On Windows, provide: BAZEL_SH, and BAZEL_LLVM (if using clang-cl) +# On all platforms, provide: PYTHON3_BIN_PATH=python +############################################################################### +build --action_env=PATH +# For --compilation_mode=dbg, consider enabling checks in the standard library as well (below). 
+build --compilation_mode=opt +# Using C++ 17 on all platforms. +build:linux --cxxopt="-std=c++17" +build:macos --cxxopt="-std=c++17" +build:clang-cl --cxxopt="-std=c++17" +build:msvc-cl --cxxopt="/std:c++17" +build:windows --cxxopt="/std:c++17" +# This workaround is needed to prevent Bazel from compiling the same file twice (once PIC and once not). +build:linux --force_pic +build:macos --force_pic +build:clang-cl --compiler=clang-cl +build:msvc-cl --compiler=msvc-cl +# `LC_ALL` and `LANG` is needed for cpp worker tests, because they will call "ray start". +# If we don't add them, python's `click` library will raise an error. +build --action_env=LC_ALL +build --action_env=LANG +# Allow C++ worker tests to execute "ray start" with the correct version of Python. +build --action_env=VIRTUAL_ENV +build --action_env=PYENV_VIRTUAL_ENV +build --action_env=PYENV_VERSION +build --action_env=PYENV_SHELL +# This is needed for some core tests to run correctly +build:windows --enable_runfiles +# TODO(mehrdadn): Revert the "-\\.(asm|S)$" exclusion when this Bazel bug +# for compiling assembly files is fixed on Windows: +# https://github.com/bazelbuild/bazel/issues/8924 +# Warnings should be errors +build:linux --per_file_copt="-\\.(asm|S)$@-Werror" +build:macos --per_file_copt="-\\.(asm|S)$@-Werror" +build:clang-cl --per_file_copt="-\\.(asm|S)$@-Werror" +build:msvc-cl --per_file_copt="-\\.(asm|S)$@-WX" +# Ignore warnings for protobuf generated files and external projects. 
+build --per_file_copt="\\.pb\\.cc$@-w" +build --per_file_copt="-\\.(asm|S)$,external/.*@-w" +#build --per_file_copt="external/.*@-Wno-unused-result" +# Ignore minor warnings for host tools, which we generally can't control +build:clang-cl --host_copt="-Wno-inconsistent-missing-override" +build:clang-cl --host_copt="-Wno-microsoft-unqualified-friend" +# This workaround is needed due to https://github.com/bazelbuild/bazel/issues/4341 +build --per_file_copt="-\\.(asm|S)$,external/com_github_grpc_grpc/.*@-DGRPC_BAZEL_BUILD" +# Don't generate warnings about kernel features we don't need https://github.com/ray-project/ray/issues/6832 +build:linux --per_file_copt="-\\.(asm|S)$,external/com_github_grpc_grpc/.*@-DGPR_MANYLINUX1" +# Ignore wchar_t -> char conversion warning on MSVC +build:msvc-cl --per_file_copt="external/boost/libs/regex/src/wc_regex_traits\\.cpp@-wd4244" +build --http_timeout_scaling=5.0 +build --verbose_failures +build:iwyu --experimental_action_listener=//:iwyu_cpp + +# Print relative paths when possible +build:windows --attempt_to_print_relative_paths +# Save disk space by hardlinking cache hits instead of copying +build:windows --experimental_repository_cache_hardlinks +# Clean the environment before building, to make builds more deterministic +build:windows --incompatible_strict_action_env +# For colored output (seems necessary on Windows) +build:windows --color=yes +# For compiler colored output (seems necessary on Windows) +build:clang-cl --per_file_copt="-\\.(asm|S)$@-fansi-escape-codes" +build:clang-cl --per_file_copt="-\\.(asm|S)$@-fcolor-diagnostics" + +build:manylinux2010 --copt="-Wno-unused-result" +build:manylinux2010 --linkopt="-lrt" + +# Debug build flags. 
Uncomment in '-c dbg' builds to enable checks in the C++ standard library: +#build:linux --cxxopt="-D_GLIBCXX_DEBUG=1" +#build:linux --cxxopt="-D_GLIBCXX_DEBUG_PEDANTIC=1" +#build:linux --cxxopt="-D_LIBCPP_DEBUG=1" +#build:macos --cxxopt="-D_GLIBCXX_DEBUG=1" +#build:macos --cxxopt="-D_GLIBCXX_DEBUG_PEDANTIC=1" +#build:windows --cxxopt="-D_ITERATOR_DEBUG_LEVEL=2" + +# LLVM (clang & libc++) build flags common across Linux installations and systems. +# On Ubuntu, the remaining configurations can be generated by running ci/travis/install-llvm-binaries.sh +build:llvm --action_env=CXXFLAGS=-stdlib=libc++ +build:llvm --action_env=LDFLAGS=-stdlib=libc++ +build:llvm --action_env=BAZEL_CXXOPTS=-stdlib=libc++ +build:llvm --action_env=BAZEL_LINKLIBS=-l%:libc++.a:-l%:libc++abi.a +build:llvm --action_env=BAZEL_LINKOPTS=-lm:-pthread +build:llvm --define force_libcpp=enabled + +# Thread sanitizer configuration: +build:tsan --strip=never +build:tsan --copt -fsanitize=thread +build:tsan --copt -DTHREAD_SANITIZER +build:tsan --copt -O2 +build:tsan --copt -g +build:tsan --copt -fno-omit-frame-pointer +build:tsan --copt -Wno-uninitialized +build:tsan --linkopt -fsanitize=thread +# This config is only for running TSAN with LLVM toolchain on Linux. +build:tsan-clang --config=tsan +build:tsan-clang --config=llvm +test:tsan --test_env=TSAN_OPTIONS="report_atomic_races=0" + +# Memory sanitizer configuration +build:asan --strip=never +build:asan --copt -g +build:asan --copt -fsanitize=address +build:asan --copt -DADDRESS_SANITIZER +build:asan --copt -fno-omit-frame-pointer +build:asan --linkopt -fsanitize=address +test:asan --jobs=1 +test:asan --test_env=ASAN_OPTIONS="detect_leaks=0" +# This config is only for running ASAN with LLVM toolchain on Linux. 
+# https://github.com/google/sanitizers/issues/1017 +build:asan-clang --config=asan +build:asan-clang --config=llvm +build:asan-clang --copt -mllvm +build:asan-clang --copt -asan-use-private-alias=1 + +build:asan-build --strip=never +build:asan-build -c dbg +build:asan-build --copt -fsanitize=address +build:asan-build --copt -DADDRESS_SANITIZER +build:asan-build --copt -O1 +build:asan-build --copt -g +build:asan-build --copt -fno-omit-frame-pointer +build:asan-build --copt -static-libasan +build:asan-build --linkopt -fsanitize=address +build:asan-build --linkopt -static-libasan +# For example, for Ubuntu 18.04 libasan can be found here: +# test:asan --test_env=LD_PRELOAD="/usr/lib/gcc/x86_64-linux-gnu/7/libasan.so" +test:asan-buildkite --test_env=LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libasan.so.5" + +# CI configuration: +aquery:ci --color=no +aquery:ci --noshow_progress +build:ci --color=yes +build:ci --curses=no +build:ci --keep_going +build:ci --progress_report_interval=100 +build:ci --show_progress_rate_limit=15 +build:ci --show_task_finish +build:ci --ui_actions_shown=1024 +build:ci --show_timestamps +build:ci-travis --disk_cache=~/ray-bazel-cache +build:ci-travis --remote_cache="https://storage.googleapis.com/ray-bazel-cache" +build:ci-github --experimental_repository_cache_hardlinks # GitHub Actions has low disk space, so prefer hardlinks there. +build:ci-github --disk_cache=~/ray-bazel-cache +build:ci-github --remote_cache="https://storage.googleapis.com/ray-bazel-cache" +test:ci --flaky_test_attempts=3 +# Disable test result caching because py_test under Bazel can import from outside of sandbox, but Bazel only looks at +# declared dependencies to determine if a result should be cached. 
More details at: +# https://github.com/bazelbuild/bazel/issues/7091, https://github.com/bazelbuild/rules_python/issues/382 +test:ci --nocache_test_results +test:ci --spawn_strategy=local +test:ci --test_output=errors +test:ci --test_verbose_timeout_warnings +test:ci-debug -c dbg +test:ci-debug --copt="-g" +test:ci-debug --flaky_test_attempts=3 +# Disable test result caching for the same reason above. +test:ci-debug --nocache_test_results +test:ci-debug --spawn_strategy=local +test:ci-debug --test_output=errors +test:ci-debug --test_verbose_timeout_warnings + +aquery:get-toolchain --include_commandline=false +aquery:get-toolchain --noimplicit_deps + +# [Linux] Uncomment this line (or use --config) to print a stack trace on exit. +#test:linux --config=strace +test:strace --run_under="bash -c 'if command -v strace >/dev/null && strace -qq -k -e exit true 2>/dev/null; then strace -qq -k -e exit -e trace=\"!all\" -s 32768 -f -o >(awk \"/^[0-9]+ / { y = \\$3 != \\\"SIGCHLD\\\" && \\$3 != \\\"SIGTERM\\\" && \\$5 != \\\"SIGTERM\\\" && \\$5 != \\\"SIGKILL2\\\"; } y { print; }\" 1>&2 && cat 1>&2) -- \"$@\"; else \"$@\"; fi' -" +# [Linux] Uncomment this line (or use --config) to preload libSegFault.so if available, to print a stack trace on aborts and segfault. (Note: This doesn't always work.) +#test:linux --config=segfault +test:segfault --run_under="bash -c 'unset GREP_OPTIONS && if ! 
grep -q -o Microsoft /proc/version 2>/dev/null; then libs=\"$(command -v ldconfig >/dev/null && ldconfig -p | grep -F -o -e \"libSegFault.so\" | uniq | tr \"\\\\n\" :)\" && if [ -n \"${libs%:}\" ]; then export SEGFAULT_SIGNALS=\"abrt segv\" LD_PRELOAD=\"${libs}${LD_PRELOAD-}\"; fi; fi && \"$@\"' -" + +# Debug build: +build:debug -c dbg +build:debug --copt="-g" +build:debug --strip="never" + +# Undefined Behavior Sanitizer +build:ubsan --strip=never +build:ubsan --copt -fsanitize=undefined +build:ubsan --copt -fno-sanitize=vptr +build:ubsan --copt -fno-sanitize-recover=all +build:ubsan --copt -g +build:ubsan --linkopt -fsanitize=undefined +build:ubsan --linkopt -fno-sanitize-recover=all + +# Import local specific llvm config options, which can be generated by +# ci/travis/install-llvm-dependencies.sh +try-import %workspace%/.llvm-local.bazelrc diff --git a/streaming/BUILD.bazel b/streaming/BUILD.bazel new file mode 100644 index 00000000..6b8458fb --- /dev/null +++ b/streaming/BUILD.bazel @@ -0,0 +1,463 @@ +# Bazel build +# C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html + +load("@rules_proto_grpc//python:defs.bzl", "python_proto_compile") +load("@com_github_ray_project_ray//bazel:ray.bzl", "COPTS", "copy_to_workspace") +load("@rules_python//python:defs.bzl", "py_library") + +package( + default_visibility = ["//visibility:public"], +) + +config_setting( + name = "msvc-cl", + flag_values = {"@bazel_tools//tools/cpp:compiler": "msvc-cl"}, +) + +config_setting( + name = "clang-cl", + flag_values = {"@bazel_tools//tools/cpp:compiler": "clang-cl"}, +) + +config_setting( + name = "opt", + values = {"compilation_mode": "opt"}, +) + +proto_library( + name = "streaming_proto", + srcs = ["src/protobuf/streaming.proto"], + strip_import_prefix = "src", + visibility = ["//visibility:public"], +) + +proto_library( + name = "streaming_queue_proto", + srcs = ["src/protobuf/streaming_queue.proto"], + strip_import_prefix = "src", +) + +proto_library( + name 
= "remote_call_proto", + srcs = ["src/protobuf/remote_call.proto"], + strip_import_prefix = "src", + visibility = ["//visibility:public"], + deps = [ + "streaming_proto", + "@com_google_protobuf//:any_proto", + ], +) + +cc_proto_library( + name = "streaming_cc_proto", + deps = [":streaming_proto"], +) + +cc_proto_library( + name = "streaming_queue_cc_proto", + deps = ["streaming_queue_proto"], +) + +# Use `linkshared` to ensure ray related symbols are not packed into streaming libs +# to avoid duplicate symbols. In runtime we expose ray related symbols, which can +# be linked into streaming libs by dynamic linker. See bazel rule `//:_raylet` +cc_binary( + name = "ray_util.so", + copts = COPTS, + linkshared = 1, + visibility = ["//visibility:public"], + deps = ["@com_github_ray_project_ray//:ray_util"], +) + +cc_binary( + name = "ray_common.so", + copts = COPTS, + linkshared = 1, + visibility = ["//visibility:public"], + deps = ["@com_github_ray_project_ray//:ray_common"], +) + +cc_binary( + name = "stats_lib.so", + copts = COPTS, + linkshared = 1, + visibility = ["//visibility:public"], + deps = ["@com_github_ray_project_ray//:stats_lib"], +) + +cc_binary( + name = "core_worker_lib.so", + copts = COPTS, + linkshared = 1, + deps = ["@com_github_ray_project_ray//:core_worker_lib"], +) + +cc_binary( + name = "exported_streaming_internal.so", + copts = COPTS, + linkshared = 1, + deps = ["@com_github_ray_project_ray//:exported_streaming_internal"], +) + +cc_library( + name = "streaming_util", + srcs = glob([ + "src/util/*.cc", + ]), + hdrs = glob([ + "src/util/*.h", + ]), + copts = COPTS, + includes = ["src"], + visibility = ["//visibility:public"], + deps = [ + "ray_common.so", + "ray_util.so", + "@boost//:any", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "streaming_metrics", + srcs = glob([ + "src/metrics/*.cc", + ]), + hdrs = glob([ + "src/metrics/*.h", + ]), + copts = COPTS, + strip_include_prefix = "src", + visibility = 
["//visibility:public"], + deps = [ + "stats_lib.so", + ":streaming_config", + ":streaming_util", + ], +) + +cc_library( + name = "streaming_config", + srcs = glob([ + "src/config/*.cc", + ]), + hdrs = glob([ + "src/config/*.h", + ]), + copts = COPTS, + strip_include_prefix = "src", + deps = [ + "ray_common.so", + ":streaming_cc_proto", + ":streaming_util", + ], +) + +cc_library( + name = "streaming_message", + srcs = glob([ + "src/message/*.cc", + ]), + hdrs = glob([ + "src/message/*.h", + ]), + copts = COPTS, + strip_include_prefix = "src", + deps = [ + "ray_common.so", + ":streaming_config", + ":streaming_util", + ], +) + +cc_library( + name = "streaming_queue", + srcs = glob([ + "src/queue/*.cc", + ]), + hdrs = glob([ + "src/queue/*.h", + ]), + copts = COPTS, + strip_include_prefix = "src", + deps = [ + "ray_common.so", + "ray_util.so", + ":streaming_config", + ":streaming_message", + ":streaming_queue_cc_proto", + ":streaming_util", + "@boost//:asio", + "@boost//:thread", + ] + select({ + "@bazel_tools//src/conditions:windows": [ + # TODO(mehrdadn): This is to resolve symbols on Windows for now. Should remove this later. (See d7f8d18.) 
+ "@com_github_ray_project_ray//:core_worker_lib", + "@com_github_ray_project_ray//:exported_streaming_internal", + ], + "//conditions:default": [ + "core_worker_lib.so", + "exported_streaming_internal.so", + ], + }), +) + +cc_library( + name = "streaming_channel", + srcs = glob(["src/channel/*.cc"]), + hdrs = glob(["src/channel/*.h"]), + copts = COPTS, + visibility = ["//visibility:public"], + deps = [ + ":streaming_common", + ":streaming_message", + ":streaming_queue", + ":streaming_ring_buffer", + ":streaming_util", + ], +) + +cc_library( + name = "streaming_reliability", + srcs = glob(["src/reliability/*.cc"]), + hdrs = glob(["src/reliability/*.h"]), + copts = COPTS, + includes = ["src/"], + visibility = ["//visibility:public"], + deps = [ + ":streaming_channel", + ":streaming_message", + ":streaming_util", + ], +) + +cc_library( + name = "streaming_ring_buffer", + srcs = glob(["src/ring_buffer/*.cc"]), + hdrs = glob(["src/ring_buffer/*.h"]), + copts = COPTS, + includes = ["src/"], + visibility = ["//visibility:public"], + deps = [ + "core_worker_lib.so", + ":ray_common.so", + ":ray_util.so", + ":streaming_message", + "@boost//:circular_buffer", + "@boost//:thread", + ], +) + +cc_library( + name = "streaming_common", + srcs = glob(["src/common/*.cc"]), + hdrs = glob(["src/common/*.h"]), + copts = COPTS, + includes = ["src/"], + visibility = ["//visibility:public"], + deps = [], +) + +cc_library( + name = "streaming_lib", + srcs = glob([ + "src/*.cc", + ]), + hdrs = glob([ + "src/*.h", + "src/queue/*.h", + "src/test/*.h", + ]), + copts = COPTS, + strip_include_prefix = "src", + visibility = ["//visibility:public"], + deps = [ + "ray_common.so", + "ray_util.so", + ":streaming_channel", + ":streaming_common", + ":streaming_config", + ":streaming_message", + ":streaming_metrics", + ":streaming_queue", + ":streaming_reliability", + ":streaming_util", + ], +) + +test_common_deps = [ + "@com_github_ray_project_ray//:exported_streaming_internal", + ":streaming_lib", + 
"@com_github_ray_project_ray//:ray_common", + "@com_github_ray_project_ray//:ray_util", + "@com_github_ray_project_ray//:core_worker_lib", +] + +# streaming queue mock actor binary +cc_binary( + name = "streaming_test_worker", + srcs = glob(["src/test/*.h"]) + [ + "src/test/mock_actor.cc", + ], + copts = COPTS, + deps = test_common_deps, +) + +# use src/test/run_streaming_queue_test.sh to run this test +cc_binary( + name = "streaming_queue_tests", + srcs = glob(["src/test/*.h"]) + [ + "src/test/streaming_queue_tests.cc", + ], + copts = COPTS, + deps = test_common_deps, +) + +cc_test( + name = "streaming_message_ring_buffer_tests", + srcs = [ + "src/test/ring_buffer_tests.cc", + ], + copts = COPTS, + tags = ["team:ant-group"], + deps = test_common_deps, +) + +cc_test( + name = "barrier_helper_tests", + srcs = [ + "src/test/barrier_helper_tests.cc", + ], + copts = COPTS, + tags = ["team:ant-group"], + deps = test_common_deps, +) + +cc_test( + name = "streaming_message_serialization_tests", + srcs = [ + "src/test/message_serialization_tests.cc", + ], + copts = COPTS, + tags = ["team:ant-group"], + deps = test_common_deps, +) + +cc_test( + name = "streaming_mock_transfer", + srcs = [ + "src/test/mock_transfer_tests.cc", + ], + copts = COPTS, + tags = ["team:ant-group"], + deps = test_common_deps, +) + +cc_test( + name = "streaming_util_tests", + srcs = [ + "src/test/streaming_util_tests.cc", + ], + copts = COPTS, + tags = ["team:ant-group"], + deps = test_common_deps, +) + +cc_test( + name = "streaming_perf_tests", + srcs = [ + "src/test/streaming_perf_tests.cc", + ], + copts = COPTS, + tags = ["team:ant-group"], + deps = test_common_deps, +) + +cc_test( + name = "event_service_tests", + srcs = [ + "src/test/event_service_tests.cc", + ], + copts = COPTS, + tags = ["team:ant-group"], + deps = test_common_deps, +) + +cc_test( + name = "queue_protobuf_tests", + srcs = [ + "src/test/queue_protobuf_tests.cc", + ], + tags = ["team:ant-group"], + deps = test_common_deps, +) + 
+cc_test( + name = "data_writer_tests", + srcs = [ + "src/test/data_writer_tests.cc", + ], + copts = COPTS, + tags = ["team:ant-group"], + deps = test_common_deps, +) + +python_proto_compile( + name = "streaming_py_proto", + deps = [":streaming_proto"], +) + +python_proto_compile( + name = "remote_call_py_proto", + deps = [":remote_call_proto"], +) + +filegroup( + name = "all_py_proto", + srcs = [ + ":remote_call_py_proto", + ":streaming_py_proto", + ], +) + +copy_to_workspace( + name = "cp_all_py_proto", + srcs = [":all_py_proto"], + dstdir = "streaming/python/generated", +) + +genrule( + name = "copy_streaming_py_proto", + srcs = [ + ":cp_all_py_proto", + ], + outs = [ + "copy_streaming_py_proto.out", + ], + cmd = """ + GENERATED_DIR="streaming/python/generated" + mkdir -p "$$GENERATED_DIR" + touch "$$GENERATED_DIR/__init__.py" + # Use this `sed` command to change the import path in the generated file. + sed -i -E 's/from streaming.src.protobuf/from ./' "$$GENERATED_DIR/remote_call_pb2.py" + sed -i -E 's/from protobuf/from ./' "$$GENERATED_DIR/remote_call_pb2.py" + date > $@ + """, + local = 1, + visibility = ["//visibility:public"], +) + +cc_binary( + name = "libstreaming_java.so", + srcs = glob([ + "src/lib/java/*.cc", + "src/lib/java/*.h", + ]), + copts = COPTS, + linkshared = 1, + visibility = ["//visibility:public"], + deps = [ + ":streaming_lib", + "@bazel_tools//tools/jdk:jni", + ], +) diff --git a/streaming/README.rst b/streaming/README.rst new file mode 100644 index 00000000..90e4a676 --- /dev/null +++ b/streaming/README.rst @@ -0,0 +1,232 @@ + +Ray Streaming +============= + +Ray Streaming is a streaming data processing framework built on ray. It will be helpful for you to build jobs dealing with real-time data. + +Key Features +------------ + + +#. + **Cross Language**. Based on Ray's multi-language actor, Ray Streaming can also run in multiple + languages(only Python and Java is supported currently) with high efficiency. 
You can implement your + operator in different languages and run them in one job. + +#. + **Single Node Failover**. We designed a special failover mechanism that only needs to roll back the + failed node itself, in most cases, to recover the job. This will be a huge benefit if your job is + sensitive to failure recovery time. In other frameworks like Flink, instead, the entire job should + be restarted once a node has a failure. + +Examples +-------- + +Python +^^^^^^ + +.. code-block:: Python + + import ray + from ray.streaming import StreamingContext + + ctx = StreamingContext.Builder() \ + .build() + ctx.read_text_file(__file__) \ + .set_parallelism(1) \ + .flat_map(lambda x: x.split()) \ + .map(lambda x: (x, 1)) \ + .key_by(lambda x: x[0]) \ + .reduce(lambda old_value, new_value: + (old_value[0], old_value[1] + new_value[1])) \ + .filter(lambda x: "ray" not in x) \ + .sink(lambda x: print("result", x)) + ctx.submit("word_count") + +Java +^^^^ + +.. code-block:: Java + + StreamingContext context = StreamingContext.buildContext(); + List text = Collections.singletonList("hello world"); + DataStreamSource.fromCollection(context, text) + .flatMap((FlatMapFunction) (value, collector) -> { + String[] records = value.split(" "); + for (String record : records) { + collector.collect(new WordAndCount(record, 1)); + } + }) + .filter(pair -> !pair.word.contains("world")) + .keyBy(pair -> pair.word) + .reduce((oldValue, newValue) -> + new WordAndCount(oldValue.word, oldValue.count + newValue.count)) + .sink(result -> System.out.println("sink result=" + result)); + context.execute("testWordCount"); + +Use Java Operators in Python +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +..
code-block:: Python + + import ray + from ray.streaming import StreamingContext + + ctx = StreamingContext.Builder().build() + ctx.from_values("a", "b", "c") \ + .as_java_stream() \ + .map("io.ray.streaming.runtime.demo.HybridStreamTest$Mapper1") \ + .filter("io.ray.streaming.runtime.demo.HybridStreamTest$Filter1") \ + .as_python_stream() \ + .sink(lambda x: print("result", x)) + ctx.submit("HybridStreamTest") + +Use Python Operators in Java +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: Java + + StreamingContext context = StreamingContext.buildContext(); + DataStreamSource streamSource = + DataStreamSource.fromCollection(context, Arrays.asList("a", "b", "c")); + streamSource + .map(x -> x + x) + .asPythonStream() + .map("ray.streaming.tests.test_hybrid_stream", "map_func1") + .filter("ray.streaming.tests.test_hybrid_stream", "filter_func1") + .asJavaStream() + .sink(value -> System.out.println("HybridStream sink=" + value)); + context.execute("HybridStreamTestJob"); + +Installation +------------ + +Python +^^^^^^ + +Ray Streaming is packaged together with Ray, install Ray with: ``pip install ray``\ , +this wheel contains all dependencies you need to run Python streaming, including support for Java operators. + +Java +^^^^ + +Import Ray Streaming using maven: + +.. code-block:: xml + + + ray-api + io.ray + 1.0.1 + + + ray-runtime + io.ray + 1.0.1 + + + streaming-api + io.ray + 1.0.1 + + + streaming-runtime + io.ray + 1.0.1 + + +Internal Design +--------------- + +Overall Architecture +^^^^^^^^^^^^^^^^^^^^ + + +.. image:: assets/architecture.jpg + :target: assets/architecture.jpg + :alt: architecture + + +Ray Streaming is built on Ray. We use Ray's actor to run everything, and use Ray's direct call for communication. + +There are two main types of actor: job master and job worker. + +When you execute ``context.submit()`` in your driver, we'll first create a job master, then job master will create all job workers needed to run your operator.
Then job master will be responsible to coordinate all workers, including checkpoint, failover, etc. + +Check `Ray Streaming Proposal `_ +to get more detailed information about the overall design. + +Fault Tolerance Mechanism +^^^^^^^^^^^^^^^^^^^^^^^^^ + +As mentioned above, different from other frameworks, We designed a special failover mechanism that only needs to rollback the failed node it's own, in most cases, to recover the job. The main idea to achieve this feature is saving messages for each node, and replay them from upstream when node has failure. + +Check `Fault Tolerance Proposal `_ +for more detailed information about our fault tolerance mechanism. + +Development Guides +------------------ + + +#. + Build streaming java + + + * build ray + + * ``bazel build //java:gen_maven_deps`` + * ``cd java && mvn clean install -Dmaven.test.skip=true && cd ..`` + + * build streaming + + * ``bazel build //streaming/java:gen_maven_deps`` + * ``mvn clean install -Dmaven.test.skip=true`` + +#. + Build ray python will build ray streaming python. + +#. + Run examples + + .. code-block:: bash + + # c++ test + cd streaming/ && bazel test ... + sh src/test/run_streaming_queue_test.sh + cd .. + + # python test + pushd python/ray/streaming/ + pushd examples + python simple.py --input-file toy.txt + popd + pushd tests + pytest . + popd + popd + + # java test + cd streaming/java/streaming-runtime + mvn test + + +More Information +---------------- + + +* `Ray Streaming implementation plan `_ +* `Fault Tolerance Proposal `_ +* `Data Transfer Proposal `_ +* `Ray Streaming Proposal `_ +* `Open Source Plan `_ + +Getting Involved +---------------- + +- `Community Slack`_: Join our Slack workspace. +- `GitHub Discussions`_: For discussions about development, questions about usage, and feature requests. +- `GitHub Issues`_: For reporting bugs. + +.. _`GitHub Discussions`: https://github.com/ray-project/ray/discussions +.. _`GitHub Issues`: https://github.com/ray-project/ray/issues +.. 
_`Community Slack`: https://forms.gle/9TSdDYUgxYs8SA9e8 diff --git a/streaming/WORKSPACE b/streaming/WORKSPACE new file mode 100644 index 00000000..785bfbb0 --- /dev/null +++ b/streaming/WORKSPACE @@ -0,0 +1,36 @@ +workspace(name = "com_github_ray_streaming") + +# LOAD RAY WORKSPACE +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file") +http_archive( + name="com_github_ray_project_ray", + strip_prefix = "ray-ray-1.9.2", + urls = ["https://github.com/ray-project/ray/archive/refs/tags/ray-1.9.2.zip"], + sha256 = "96b77d4af6058d7f06c285ac9f6e9dc0c9f1f51bdb7e63e41a61da6ae79a329c", +# build_file = "@com_github_ray_project_ray//BUILD.bazel", +) + + +load("//bazel:streaming_deps_setup.bzl", "streaming_deps_setup") + +streaming_deps_setup() +#load("@com_github_ray_project_ray//bazel:ray_deps_setup.bzl", "ray_deps_setup") + +#ray_deps_setup() +# +load("@com_github_ray_project_ray//bazel:ray_deps_build_all.bzl", "ray_deps_build_all") + +ray_deps_build_all() +# +## This needs to be run after grpc_deps() in ray_deps_build_all() to make +## sure all the packages loaded by grpc_deps() are available. However a +## load() statement cannot be in a function so we put it here. +load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") + +grpc_extra_deps() + +load("@bazel_skylib//lib:versions.bzl", "versions") + +# When the bazel version is updated, make sure to update it +# in setup.py as well. 
+versions.check(minimum_bazel_version = "4.2.1") diff --git a/streaming/assets/architecture.jpg b/streaming/assets/architecture.jpg new file mode 100644 index 00000000..9aab86b9 Binary files /dev/null and b/streaming/assets/architecture.jpg differ diff --git a/streaming/bazel/BUILD b/streaming/bazel/BUILD new file mode 100644 index 00000000..e69de29b diff --git a/streaming/bazel/streaming_deps_setup.bzl b/streaming/bazel/streaming_deps_setup.bzl new file mode 100644 index 00000000..56d029bc --- /dev/null +++ b/streaming/bazel/streaming_deps_setup.bzl @@ -0,0 +1,186 @@ +load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file") +load("@com_github_ray_project_ray//bazel:ray_deps_setup.bzl", "auto_http_archive") + +def streaming_deps_setup(): + auto_http_archive( + name = "rules_proto_grpc", + url = "https://github.com/rules-proto-grpc/rules_proto_grpc/archive/a74fef39c5fe636580083545f76d1eab74f6450d.tar.gz", + sha256 = "2f6606151ec042e23396f07de9e7dcf6ca9a5db1d2b09f0cc93a7fc7f4008d1b", + ) + + auto_http_archive( + name = "com_github_google_flatbuffers", + url = "https://github.com/google/flatbuffers/archive/63d51afd1196336a7d1f56a988091ef05deb1c62.tar.gz", + sha256 = "3f469032571d324eabea88d7014c05fec8565a5877dbe49b2a52d8d1a0f18e63", + ) + + auto_http_archive( + name = "bazel_common", + url = "https://github.com/google/bazel-common/archive/084aadd3b854cad5d5e754a7e7d958ac531e6801.tar.gz", + sha256 = "a6e372118bc961b182a3a86344c0385b6b509882929c6b12dc03bb5084c775d5", + ) + + auto_http_archive( + name = "bazel_skylib", + strip_prefix = None, + url = "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz", + sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44", + ) + + http_archive( + name = "com_google_protobuf", + strip_prefix = "protobuf-3.16.0", + urls = 
["https://github.com/protocolbuffers/protobuf/archive/v3.16.0.tar.gz"], + sha256 = "7892a35d979304a404400a101c46ce90e85ec9e2a766a86041bb361f626247f5", + ) + + http_archive( + name = "rules_python", + sha256 = "a30abdfc7126d497a7698c29c46ea9901c6392d6ed315171a6df5ce433aa4502", + strip_prefix = "rules_python-0.6.0", + url = "https://github.com/bazelbuild/rules_python/archive/0.6.0.tar.gz", + ) + + auto_http_archive( + name = "com_github_grpc_grpc", + # NOTE: If you update this, also update @boringssl's hash. + url = "https://github.com/grpc/grpc/archive/refs/tags/v1.38.1.tar.gz", + sha256 = "f60e5b112913bf776a22c16a3053cc02cf55e60bf27a959fd54d7aaf8e2da6e8", + patches = [ + "@com_github_ray_project_ray//thirdparty/patches:grpc-cython-copts.patch", + "@com_github_ray_project_ray//thirdparty/patches:grpc-python.patch", + "@com_github_ray_project_ray//thirdparty/patches:grpc-windows-python-header-path.patch", + ], + ) + + auto_http_archive( + # This rule is used by @com_github_nelhage_rules_boost and + # declaring it here allows us to avoid patching the latter. 
+ name = "boost", + build_file = "@com_github_nelhage_rules_boost//:BUILD.boost", + sha256 = "83bfc1507731a0906e387fc28b7ef5417d591429e51e788417fe9ff025e116b1", + url = "https://boostorg.jfrog.io/artifactory/main/release/1.74.0/source/boost_1_74_0.tar.bz2", + patches = [ + "@com_github_ray_project_ray//thirdparty/patches:boost-exception-no_warn_typeid_evaluated.patch", + ], + ) + + auto_http_archive( + name = "com_google_googletest", + url = "https://github.com/google/googletest/archive/refs/tags/release-1.11.0.tar.gz", + sha256 = "b4870bf121ff7795ba20d20bcdd8627b8e088f2d1dab299a031c1034eddc93d5", + ) + + #http_archive( + # name = "zlib", + # build_file = "@com_google_protobuf//:third_party/zlib.BUILD", + # sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", + # strip_prefix = "zlib-1.2.11", + # urls = [ + # "https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz", + # "https://zlib.net/zlib-1.2.11.tar.gz", + # ], + #) + + http_archive( + name = "nlohmann_json", + strip_prefix = "json-3.9.1", + urls = ["https://github.com/nlohmann/json/archive/v3.9.1.tar.gz"], + sha256 = "4cf0df69731494668bdd6460ed8cb269b68de9c19ad8c27abc24cd72605b2d5b", + build_file = "@com_github_ray_project_ray//bazel:BUILD.nlohmann_json", + ) + + auto_http_archive( + name = "io_opencensus_cpp", + url = "https://github.com/census-instrumentation/opencensus-cpp/archive/b14a5c0dcc2da8a7fc438fab637845c73438b703.zip", + sha256 = "6592e07672e7f7980687f6c1abda81974d8d379e273fea3b54b6c4d855489b9d", + patches = [ + "@com_github_ray_project_ray//thirdparty/patches:opencensus-cpp-harvest-interval.patch", + "@com_github_ray_project_ray//thirdparty/patches:opencensus-cpp-shutdown-api.patch", + ], + ) + + + # OpenCensus depends on Abseil so we have to explicitly pull it in. + # This is how diamond dependencies are prevented. 
+ auto_http_archive( + name = "com_google_absl", + url = "https://github.com/abseil/abseil-cpp/archive/refs/tags/20210324.2.tar.gz", + sha256 = "59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f", + ) + + auto_http_archive( + name = "com_github_spdlog", + build_file = "@com_github_ray_project_ray//bazel:BUILD.spdlog", + urls = ["https://github.com/gabime/spdlog/archive/v1.7.0.zip"], + sha256 = "c8f1e1103e0b148eb8832275d8e68036f2fdd3975a1199af0e844908c56f6ea5", + ) + + # OpenCensus depends on jupp0r/prometheus-cpp + auto_http_archive( + name = "com_github_jupp0r_prometheus_cpp", + url = "https://github.com/jupp0r/prometheus-cpp/archive/60eaa4ea47b16751a8e8740b05fe70914c68a480.tar.gz", + sha256 = "ec825b802487ac18b0d98e2e8b7961487b12562f8f82e424521d0a891d9e1373", + patches = [ + "@com_github_ray_project_ray//thirdparty/patches:prometheus-windows-headers.patch", + # https://github.com/jupp0r/prometheus-cpp/pull/225 + "@com_github_ray_project_ray//thirdparty/patches:prometheus-windows-zlib.patch", + "@com_github_ray_project_ray//thirdparty/patches:prometheus-windows-pollfd.patch", + ], + ) + + auto_http_archive( + name = "msgpack", + build_file = "@com_github_ray_project_ray//bazel:BUILD.msgpack", + url = "https://github.com/msgpack/msgpack-c/archive/8085ab8721090a447cf98bb802d1406ad7afe420.tar.gz", + sha256 = "83c37c9ad926bbee68d564d9f53c6cbb057c1f755c264043ddd87d89e36d15bb", + patches = [ + "@com_github_ray_project_ray//thirdparty/patches:msgpack-windows-iovec.patch", + ], + ) + + auto_http_archive( + # This rule is used by @com_github_nelhage_rules_boost and + # declaring it here allows us to avoid patching the latter. 
+ name = "boost", + build_file = "@com_github_nelhage_rules_boost//:BUILD.boost", + sha256 = "83bfc1507731a0906e387fc28b7ef5417d591429e51e788417fe9ff025e116b1", + url = "https://boostorg.jfrog.io/artifactory/main/release/1.74.0/source/boost_1_74_0.tar.bz2", + patches = [ + "@com_github_ray_project_ray//thirdparty/patches:boost-exception-no_warn_typeid_evaluated.patch", + ], + ) + + auto_http_archive( + name = "com_github_nelhage_rules_boost", + # If you update the Boost version, remember to update the 'boost' rule. + url = "https://github.com/nelhage/rules_boost/archive/652b21e35e4eeed5579e696da0facbe8dba52b1f.tar.gz", + sha256 = "c1b8b2adc3b4201683cf94dda7eef3fc0f4f4c0ea5caa3ed3feffe07e1fb5b15", + patches = [ + "@com_github_ray_project_ray//thirdparty/patches:rules_boost-windows-linkopts.patch", + ], + ) + + auto_http_archive( + name = "rules_jvm_external", + url = "https://github.com/bazelbuild/rules_jvm_external/archive/2.10.tar.gz", + sha256 = "5c1b22eab26807d5286ada7392d796cbc8425d3ef9a57d114b79c5f8ef8aca7c", + ) + + http_archive( + name = "io_opencensus_proto", + strip_prefix = "opencensus-proto-0.3.0/src", + urls = ["https://github.com/census-instrumentation/opencensus-proto/archive/v0.3.0.tar.gz"], + sha256 = "b7e13f0b4259e80c3070b583c2f39e53153085a6918718b1c710caf7037572b0", + ) + + auto_http_archive( + name = "com_github_redis_hiredis", + build_file = "@com_github_ray_project_ray//bazel:BUILD.hiredis", + url = "https://github.com/redis/hiredis/archive/392de5d7f97353485df1237872cb682842e8d83f.tar.gz", + sha256 = "2101650d39a8f13293f263e9da242d2c6dee0cda08d343b2939ffe3d95cf3b8b", + patches = [ + "@com_github_ray_project_ray//thirdparty/patches:hiredis-windows-msvc.patch", + ], + ) \ No newline at end of file diff --git a/streaming/java/BUILD.bazel b/streaming/java/BUILD.bazel new file mode 100644 index 00000000..07ce8c7b --- /dev/null +++ b/streaming/java/BUILD.bazel @@ -0,0 +1,268 @@ +load("//bazel:ray.bzl", "define_java_module") +load("//bazel:ray.bzl", 
"native_java_binary") +load("//bazel:ray.bzl", "native_java_library") +load("@rules_proto_grpc//java:defs.bzl", "java_proto_compile") + +exports_files([ + "testng.xml", +]) + +all_modules = [ + "streaming-state", + "streaming-api", + "streaming-runtime", +] + +java_import( + name = "all_modules", + jars = [ + "libio_ray_ray_" + module + ".jar" + for module in all_modules + ] + [ + "libio_ray_ray_" + module + "-src.jar" + for module in all_modules + ] + [ + "all_streaming_tests_deploy.jar", + "all_streaming_tests_deploy-src.jar", + ], + deps = [ + ":io_ray_ray_" + module + for module in all_modules + ] + [ + ":all_streaming_tests", + ], +) + +define_java_module( + name = "streaming-api", + define_test_lib = True, + test_deps = [ + "//java:io_ray_ray_api", + ":io_ray_ray_streaming-state", + ":io_ray_ray_streaming-api", + "@maven//:com_google_guava_guava", + "@maven//:org_apache_logging_log4j_log4j_api", + "@maven//:org_apache_logging_log4j_log4j_core", + "@maven//:org_apache_logging_log4j_log4j_slf4j_impl", + "@maven//:org_slf4j_slf4j_api", + "@maven//:org_testng_testng", + "@maven//:org_yaml_snakeyaml", + "@ray_streaming_maven//:org_apache_commons_commons_lang3", + ], + visibility = ["//visibility:public"], + deps = [ + ":io_ray_ray_streaming-state", + "//java:io_ray_ray_api", + "//java:io_ray_ray_runtime", + "@maven//:com_google_guava_guava", + "@maven//:org_apache_logging_log4j_log4j_api", + "@maven//:org_apache_logging_log4j_log4j_core", + "@maven//:org_apache_logging_log4j_log4j_slf4j_impl", + "@maven//:org_slf4j_slf4j_api", + "@ray_streaming_maven//:com_google_code_findbugs_jsr305", + "@ray_streaming_maven//:com_google_code_gson_gson", + "@ray_streaming_maven//:org_apache_commons_commons_lang3", + ], +) + +define_java_module( + name = "streaming-state", + define_test_lib = True, + test_deps = [ + ":io_ray_ray_streaming-state", + "@maven//:com_google_guava_guava", + "@maven//:org_apache_logging_log4j_log4j_api", + "@maven//:org_apache_logging_log4j_log4j_core", 
+ "@maven//:org_apache_logging_log4j_log4j_slf4j_impl", + "@maven//:org_slf4j_slf4j_api", + "@maven//:org_testng_testng", + "@maven//:de_ruedigermoeller_fst", + "@maven//:org_yaml_snakeyaml", + "@ray_streaming_maven//:org_apache_commons_commons_lang3", + "@ray_streaming_maven//:org_mockito_mockito_all", + ], + visibility = ["//visibility:public"], + deps = [ + "@maven//:com_google_guava_guava", + "@maven//:de_ruedigermoeller_fst", + "@maven//:org_apache_logging_log4j_log4j_api", + "@maven//:org_apache_logging_log4j_log4j_core", + "@maven//:org_apache_logging_log4j_log4j_slf4j_impl", + "@maven//:org_slf4j_slf4j_api", + "@ray_streaming_maven//:org_apache_commons_commons_lang3", + ], +) + +native_java_library("streaming-runtime", "streaming_java", "//streaming:libstreaming_java.so") + +filegroup( + name = "java_native_deps", + srcs = [":streaming_java"], +) + +define_java_module( + name = "streaming-runtime", + additional_resources = [ + ":java_native_deps", + ], + additional_srcs = [ + ":all_java_proto", + ], + define_test_lib = True, + exclude_srcs = [ + "streaming-runtime/src/main/java/io/ray/streaming/runtime/generated/*.java", + ], + test_deps = [ + "//java:io_ray_ray_api", + "//java:io_ray_ray_runtime", + ":io_ray_ray_streaming-state", + ":io_ray_ray_streaming-api", + ":io_ray_ray_streaming-runtime", + "@maven//:com_google_guava_guava", + "@maven//:de_ruedigermoeller_fst", + "@maven//:org_apache_logging_log4j_log4j_api", + "@maven//:org_apache_logging_log4j_log4j_core", + "@maven//:org_apache_logging_log4j_log4j_slf4j_impl", + "@maven//:org_slf4j_slf4j_api", + "@maven//:org_testng_testng", + "@maven//:org_yaml_snakeyaml", + "@ray_streaming_maven//:com_google_code_findbugs_jsr305", + "@ray_streaming_maven//:org_aeonbits_owner_owner", + "@ray_streaming_maven//:org_apache_commons_commons_lang3", + "@ray_streaming_maven//:org_mockito_mockito_all", + "@ray_streaming_maven//:org_powermock_powermock_api_mockito", + 
"@ray_streaming_maven//:org_powermock_powermock_module_testng", + ], + visibility = ["//visibility:public"], + deps = [ + ":io_ray_ray_streaming-api", + ":io_ray_ray_streaming-state", + "//java:io_ray_ray_api", + "//java:io_ray_ray_runtime", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + "@maven//:commons_io_commons_io", + "@maven//:de_ruedigermoeller_fst", + "@maven//:org_apache_commons_commons_lang3", + "@maven//:org_apache_logging_log4j_log4j_api", + "@maven//:org_apache_logging_log4j_log4j_core", + "@maven//:org_apache_logging_log4j_log4j_slf4j_impl", + "@maven//:org_msgpack_msgpack_core", + "@maven//:org_slf4j_slf4j_api", + "@ray_streaming_maven//:com_github_davidmoten_flatbuffers_java", + "@ray_streaming_maven//:com_google_code_findbugs_jsr305", + "@ray_streaming_maven//:commons_collections_commons_collections", + "@ray_streaming_maven//:org_aeonbits_owner_owner", + ], +) + +java_binary( + name = "all_streaming_tests", + args = ["streaming/java/testng.xml"], + data = ["testng.xml"], + main_class = "org.testng.TestNG", + runtime_deps = [ + ":io_ray_ray_streaming-api_test", + ":io_ray_ray_streaming-runtime", + ":io_ray_ray_streaming-runtime_test", + ":io_ray_ray_streaming-state", + "//java:io_ray_ray_runtime", + "@maven//:org_testng_testng", + "@maven//:org_yaml_snakeyaml", + "@ray_streaming_maven//:org_mockito_mockito_all", + "@ray_streaming_maven//:org_powermock_powermock_api_mockito", + "@ray_streaming_maven//:org_powermock_powermock_module_testng", + ], +) + +# proto buffer +java_proto_compile( + name = "streaming_java_proto", + deps = ["//streaming:streaming_proto"], +) + +java_proto_compile( + name = "remote_call_java_proto", + deps = ["//streaming:remote_call_proto"], +) + +filegroup( + name = "all_java_proto", + srcs = [ + ":remote_call_java_proto", + ":streaming_java_proto", + ], +) + +genrule( + name = "copy_pom_file", + srcs = [ + "//streaming/java:io_ray_ray_" + module + "_pom" + for module in all_modules + ], 
+ outs = ["copy_pom_file.out"], + cmd = """ + WORK_DIR="$$(pwd)" + cp -f $(location //streaming/java:io_ray_ray_streaming-api_pom) "$$WORK_DIR/streaming/java/streaming-api/pom.xml" + cp -f $(location //streaming/java:io_ray_ray_streaming-runtime_pom) "$$WORK_DIR/streaming/java/streaming-runtime/pom.xml" + cp -f $(location //streaming/java:io_ray_ray_streaming-state_pom) "$$WORK_DIR/streaming/java/streaming-state/pom.xml" + date > $@ + """, + local = 1, + tags = ["no-cache"], +) + +genrule( + name = "cp_java_generated", + srcs = [ + ":all_java_proto", + ":copy_pom_file", + ], + outs = ["cp_java_generated.out"], + cmd = """ + WORK_DIR="$$(pwd)" + GENERATED_DIR="$$WORK_DIR/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/generated" + rm -rf "$$GENERATED_DIR" + mkdir -p "$$GENERATED_DIR" + # Copy protobuf-generated files. + for f in $(locations //streaming/java:all_java_proto); do + unzip -q -o "$$f" -x META-INF/MANIFEST.MF -d "$$WORK_DIR/streaming/java/streaming-runtime/src/main/java" + done + # remove third party protobuf + rm -rf $$WORK_DIR/streaming/java/streaming-runtime/src/main/java/com/google/protobuf/ + date > $@ + """, + local = 1, + tags = ["no-cache"], +) + +# Generates the dependencies needed by maven. +genrule( + name = "gen_maven_deps", + srcs = [ + ":java_native_deps", + ":cp_java_generated", + ], + outs = ["gen_maven_deps.out"], + cmd = """ + WORK_DIR="$${PWD}" + # Copy native dependencies. 
+ OS_NAME="" + case "$${OSTYPE}" in + linux*) OS_NAME="linux";; + darwin*) OS_NAME="darwin";; + *) echo "$${OSTYPE} is not supported currently"; exit 1;; + esac + NATIVE_DEPS_DIR="$$WORK_DIR/streaming/java/streaming-runtime/native_dependencies/native/$$OS_NAME" + rm -rf "$$NATIVE_DEPS_DIR" + mkdir -p "$$NATIVE_DEPS_DIR" + for f in $(locations //streaming/java:java_native_deps); do + chmod +w "$$f" + cp "$$f" "$$NATIVE_DEPS_DIR" + done + date > $@ + """, + local = 1, + tags = ["no-cache"], + visibility = ["//visibility:public"], +) diff --git a/streaming/java/checkstyle-suppressions.xml b/streaming/java/checkstyle-suppressions.xml new file mode 100644 index 00000000..1d86cdba --- /dev/null +++ b/streaming/java/checkstyle-suppressions.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + diff --git a/streaming/java/dependencies.bzl b/streaming/java/dependencies.bzl new file mode 100644 index 00000000..72d899ad --- /dev/null +++ b/streaming/java/dependencies.bzl @@ -0,0 +1,22 @@ +load("@rules_jvm_external//:defs.bzl", "maven_install") + +def gen_streaming_java_deps(): + maven_install( + name = "ray_streaming_maven", + artifacts = [ + "com.google.code.findbugs:jsr305:3.0.2", + "com.google.code.gson:gson:2.8.5", + "com.github.davidmoten:flatbuffers-java:1.9.0.1", + "org.apache.commons:commons-lang3:3.4", + "org.aeonbits.owner:owner:1.0.10", + "org.mockito:mockito-all:1.10.19", + "org.apache.commons:commons-lang3:3.3.2", + "org.mockito:mockito-all:1.10.19", + "org.powermock:powermock-module-testng:1.6.6", + "org.powermock:powermock-api-mockito:1.6.6", + "commons-collections:commons-collections:3.2.2", + ], + repositories = [ + "https://repo1.maven.org/maven2/", + ], + ) diff --git a/streaming/java/generate_jni_header_files.sh b/streaming/java/generate_jni_header_files.sh new file mode 100755 index 00000000..5ce3cb7a --- /dev/null +++ b/streaming/java/generate_jni_header_files.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +set -e +set -x + +cd "$(dirname "$0")" + +bazel build 
all_streaming_tests_deploy.jar + +function generate_one() +{ + file=${1//./_}.h + javah -classpath ../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar "$1" + + # prepend licence first + cat < ../src/lib/java/"$file" +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +EOF + # then append the generated header file + cat "$file" >> ../src/lib/java/"$file" + rm -f "$file" +} + +generate_one io.ray.streaming.runtime.transfer.channel.ChannelId +generate_one io.ray.streaming.runtime.transfer.DataReader +generate_one io.ray.streaming.runtime.transfer.DataWriter +generate_one io.ray.streaming.runtime.transfer.TransferHandler + +rm -f io_ray_streaming_*.h diff --git a/streaming/java/streaming-api/pom_template.xml b/streaming/java/streaming-api/pom_template.xml new file mode 100644 index 00000000..1bbe45cf --- /dev/null +++ b/streaming/java/streaming-api/pom_template.xml @@ -0,0 +1,37 @@ + + {auto_gen_header} + + + ray-streaming + io.ray + 2.0.0-SNAPSHOT + + 4.0.0 + + streaming-api + ray streaming api + ray streaming api + + jar + + + + io.ray + ray-api + ${project.version} + + + io.ray + ray-runtime + ${project.version} + + + io.ray + streaming-state + ${project.version} + + {generated_bzl_deps} + + diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/Language.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/Language.java new file mode 100644 index 00000000..e580b34f --- 
/dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/Language.java @@ -0,0 +1,6 @@ +package io.ray.streaming.api; + +public enum Language { + JAVA, + PYTHON +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/collector/CollectionCollector.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/collector/CollectionCollector.java new file mode 100644 index 00000000..04504a84 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/collector/CollectionCollector.java @@ -0,0 +1,25 @@ +package io.ray.streaming.api.collector; + +import io.ray.streaming.message.Record; +import java.util.List; + +/** + * Combination of multiple collectors. + * + * @param The type of output data. + */ +public class CollectionCollector implements Collector { + + private List collectorList; + + public CollectionCollector(List collectorList) { + this.collectorList = collectorList; + } + + @Override + public void collect(T value) { + for (Collector collector : collectorList) { + collector.collect(new Record(value)); + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/collector/Collector.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/collector/Collector.java new file mode 100644 index 00000000..ef7c1635 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/collector/Collector.java @@ -0,0 +1,12 @@ +package io.ray.streaming.api.collector; + +/** + * The collector that collects data from an upstream operator, and emits data to downstream + * operators. + * + * @param Type of the data to collect. 
+ */ +public interface Collector { + + void collect(T value); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java new file mode 100644 index 00000000..c303b499 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java @@ -0,0 +1,28 @@ +package io.ray.streaming.api.context; + +import com.google.common.base.Preconditions; +import io.ray.api.Ray; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class ClusterStarter { + + private static final Logger LOG = LoggerFactory.getLogger(ClusterStarter.class); + + static synchronized void startCluster(boolean isLocal) { + Preconditions.checkArgument(!Ray.isInitialized()); + if (!isLocal) { + System.setProperty("ray.run-mode", "CLUSTER"); + } else { + System.setProperty("ray.run-mode", "SINGLE_PROCESS"); + } + + Ray.init(); + } + + public static synchronized void stopCluster() { + // Disconnect to the cluster. 
+ Ray.shutdown(); + System.clearProperty("ray.run-mode"); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/RuntimeContext.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/RuntimeContext.java new file mode 100644 index 00000000..637e98f0 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/RuntimeContext.java @@ -0,0 +1,42 @@ +package io.ray.streaming.api.context; + +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import io.ray.streaming.state.keystate.desc.MapStateDescriptor; +import io.ray.streaming.state.keystate.desc.ValueStateDescriptor; +import io.ray.streaming.state.keystate.state.ListState; +import io.ray.streaming.state.keystate.state.MapState; +import io.ray.streaming.state.keystate.state.ValueState; +import java.util.Map; + +/** Encapsulate the runtime information of a streaming task. */ +public interface RuntimeContext { + + int getTaskId(); + + int getTaskIndex(); + + int getParallelism(); + + /** Returns config of current function */ + Map getConfig(); + + /** Returns config of the job */ + Map getJobConfig(); + + Long getCheckpointId(); + + void setCheckpointId(long checkpointId); + + void setCurrentKey(Object key); + + KeyStateBackend getKeyStateBackend(); + + void setKeyStateBackend(KeyStateBackend keyStateBackend); + + ValueState getValueState(ValueStateDescriptor stateDescriptor); + + ListState getListState(ListStateDescriptor stateDescriptor); + + MapState getMapState(MapStateDescriptor stateDescriptor); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java new file mode 100644 index 00000000..caed688d --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java 
@@ -0,0 +1,98 @@ +package io.ray.streaming.api.context; + +import com.google.common.base.Preconditions; +import io.ray.api.Ray; +import io.ray.streaming.api.stream.StreamSink; +import io.ray.streaming.client.JobClient; +import io.ray.streaming.jobgraph.JobGraph; +import io.ray.streaming.jobgraph.JobGraphBuilder; +import io.ray.streaming.jobgraph.JobGraphOptimizer; +import io.ray.streaming.util.Config; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.ServiceLoader; +import java.util.concurrent.atomic.AtomicInteger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Encapsulate the context information of a streaming Job. */ +public class StreamingContext implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(StreamingContext.class); + + private transient AtomicInteger idGenerator; + + /** The sinks of this streaming job. */ + private List streamSinks; + + /** The user custom streaming job configuration. */ + private Map jobConfig; + + /** The logic plan. */ + private JobGraph jobGraph; + + private StreamingContext() { + this.idGenerator = new AtomicInteger(0); + this.streamSinks = new ArrayList<>(); + this.jobConfig = new HashMap<>(); + } + + public static StreamingContext buildContext() { + return new StreamingContext(); + } + + /** Construct job DAG, and execute the job. 
*/ + public void execute(String jobName) { + JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(this.streamSinks, jobName); + JobGraph originalJobGraph = jobGraphBuilder.build(); + this.jobGraph = new JobGraphOptimizer(originalJobGraph).optimize(); + jobGraph.printJobGraph(); + LOG.info("JobGraph digraph\n{}", jobGraph.generateDigraph()); + + if (!Ray.isInitialized()) { + if (Config.MEMORY_CHANNEL.equalsIgnoreCase(jobConfig.get(Config.CHANNEL_TYPE))) { + ClusterStarter.startCluster(true); + LOG.info("Created local cluster for job {}.", jobName); + } else { + ClusterStarter.startCluster(false); + LOG.info("Created multi process cluster for job {}.", jobName); + } + Runtime.getRuntime().addShutdownHook(new Thread(StreamingContext.this::stop)); + } else { + LOG.info("Reuse existing cluster."); + } + + ServiceLoader serviceLoader = ServiceLoader.load(JobClient.class); + Iterator iterator = serviceLoader.iterator(); + Preconditions.checkArgument( + iterator.hasNext(), "No JobClient implementation has been provided."); + JobClient jobClient = iterator.next(); + jobClient.submit(jobGraph, jobConfig); + } + + public int generateId() { + return this.idGenerator.incrementAndGet(); + } + + public void addSink(StreamSink streamSink) { + streamSinks.add(streamSink); + } + + public List getStreamSinks() { + return streamSinks; + } + + public void withConfig(Map jobConfig) { + this.jobConfig = jobConfig; + } + + public void stop() { + if (Ray.isInitialized()) { + ClusterStarter.stopCluster(); + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/Function.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/Function.java new file mode 100644 index 00000000..c12bdf87 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/Function.java @@ -0,0 +1,28 @@ +package io.ray.streaming.api.function; + +import java.io.Serializable; + +/** Interface of streaming functions. 
*/ +public interface Function extends Serializable { + + /** + * This method will be called periodically by framework, you should return a a serializable object + * which represents function state, framework will help you to serialize this object, save it to + * storage, and load it back when in fail-over through. {@link + * Function#loadCheckpoint(Serializable)}. + * + * @return A serializable object which represents function state. + */ + default Serializable saveCheckpoint() { + return null; + } + + /** + * This method will be called by framework when a worker died and been restarted. We will pass the + * last object you returned in {@link Function#saveCheckpoint()} when doing checkpoint, you are + * responsible to load this object back to you function. + * + * @param checkpointObject the last object you returned in {@link Function#saveCheckpoint()} + */ + default void loadCheckpoint(Serializable checkpointObject) {} +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/RichFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/RichFunction.java new file mode 100644 index 00000000..acf09b75 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/RichFunction.java @@ -0,0 +1,23 @@ +package io.ray.streaming.api.function; + +import io.ray.streaming.api.context.RuntimeContext; + +/** + * An interface for all user-defined functions to define the life cycle methods of the functions, + * and access the task context where the functions get executed. + */ +public interface RichFunction extends Function { + + /** + * Initialization method for user function which called before the first call to the user + * function. + * + * @param runtimeContext runtime context + */ + void open(RuntimeContext runtimeContext); + + /** + * Tear-down method for the user function which called after the last call to the user function. 
+ */ + void close(); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/AggregateFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/AggregateFunction.java new file mode 100644 index 00000000..53ef5144 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/AggregateFunction.java @@ -0,0 +1,23 @@ +package io.ray.streaming.api.function.impl; + +import io.ray.streaming.api.function.Function; + +/** + * Interface of aggregate functions. + * + * @param Type of the input data. + * @param Type of the intermediate data. + * @param Type of the output data. + */ +public interface AggregateFunction extends Function { + + A createAccumulator(); + + void add(I value, A accumulator); + + O getResult(A accumulator); + + A merge(A a, A b); + + void retract(A acc, I value); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/FilterFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/FilterFunction.java new file mode 100644 index 00000000..d60e335a --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/FilterFunction.java @@ -0,0 +1,21 @@ +package io.ray.streaming.api.function.impl; + +import io.ray.streaming.api.function.Function; + +/** + * A filter function is a predicate applied individually to each record. The predicate decides + * whether to keep the element, or to discard it. + * + * @param type of the input data. + */ +@FunctionalInterface +public interface FilterFunction extends Function { + + /** + * The filter function that evaluates the predicate. + * + * @param value The value to be filtered. + * @return True for values that should be retained, false for values to be filtered out. 
+ */ + boolean filter(T value) throws Exception; +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/FlatMapFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/FlatMapFunction.java new file mode 100644 index 00000000..fe648cb5 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/FlatMapFunction.java @@ -0,0 +1,16 @@ +package io.ray.streaming.api.function.impl; + +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.function.Function; + +/** + * Interface of flat-map functions. + * + * @param Type of the input data. + * @param Type of the output data. + */ +@FunctionalInterface +public interface FlatMapFunction extends Function { + + void flatMap(T value, Collector collector); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/JoinFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/JoinFunction.java new file mode 100644 index 00000000..a45b6f46 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/JoinFunction.java @@ -0,0 +1,16 @@ +package io.ray.streaming.api.function.impl; + +import io.ray.streaming.api.function.Function; + +/** + * Interface of join functions. + * + * @param Type of the left input data. + * @param Type of the right input data. + * @param Type of the output data. 
+ */ +@FunctionalInterface +public interface JoinFunction extends Function { + + R join(T left, O right); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/KeyFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/KeyFunction.java new file mode 100644 index 00000000..c1cd608c --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/KeyFunction.java @@ -0,0 +1,15 @@ +package io.ray.streaming.api.function.impl; + +import io.ray.streaming.api.function.Function; + +/** + * Interface of key-by functions. + * + * @param Type of the input data. + * @param Type of the key-by field. + */ +@FunctionalInterface +public interface KeyFunction extends Function { + + K keyBy(T value); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/MapFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/MapFunction.java new file mode 100644 index 00000000..ee2c76c8 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/MapFunction.java @@ -0,0 +1,15 @@ +package io.ray.streaming.api.function.impl; + +import io.ray.streaming.api.function.Function; + +/** + * Interface of map functions. + * + * @param type of the input data. + * @param type of the output data. 
+ */ +@FunctionalInterface +public interface MapFunction extends Function { + + R map(T value); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/ProcessFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/ProcessFunction.java new file mode 100644 index 00000000..eab9e711 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/ProcessFunction.java @@ -0,0 +1,14 @@ +package io.ray.streaming.api.function.impl; + +import io.ray.streaming.api.function.Function; + +/** + * Interface of process functions. + * + * @param Type of the input data. + */ +@FunctionalInterface +public interface ProcessFunction extends Function { + + void process(T value); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/ReduceFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/ReduceFunction.java new file mode 100644 index 00000000..350ea926 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/ReduceFunction.java @@ -0,0 +1,14 @@ +package io.ray.streaming.api.function.impl; + +import io.ray.streaming.api.function.Function; + +/** + * Interface of reduce functions. + * + * @param Type of the input data. + */ +@FunctionalInterface +public interface ReduceFunction extends Function { + + T reduce(T oldValue, T newValue); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SinkFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SinkFunction.java new file mode 100644 index 00000000..31a156bc --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SinkFunction.java @@ -0,0 +1,14 @@ +package io.ray.streaming.api.function.impl; + +import io.ray.streaming.api.function.Function; + +/** + * Interface of sink functions. 
+ * + * @param Type of the sink data. + */ +@FunctionalInterface +public interface SinkFunction extends Function { + + void sink(T value); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SourceFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SourceFunction.java new file mode 100644 index 00000000..92c227b7 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SourceFunction.java @@ -0,0 +1,22 @@ +package io.ray.streaming.api.function.impl; + +import io.ray.streaming.api.function.Function; + +/** + * Interface of Source functions. + * + * @param Type of the data output by the source. + */ +public interface SourceFunction extends Function { + + void init(int parallelism, int index); + + void fetch(SourceContext ctx) throws Exception; + + void close(); + + interface SourceContext { + + void collect(T element) throws Exception; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/CollectionSourceFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/CollectionSourceFunction.java new file mode 100644 index 00000000..a1bb3841 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/CollectionSourceFunction.java @@ -0,0 +1,36 @@ +package io.ray.streaming.api.function.internal; + +import io.ray.streaming.api.function.impl.SourceFunction; +import java.util.Collection; + +/** + * The SourceFunction that fetch data from a Java Collection object. + * + * @param Type of the data output by the source. 
+ */ +public class CollectionSourceFunction implements SourceFunction { + + private Collection values; + private boolean finished = false; + + public CollectionSourceFunction(Collection values) { + this.values = values; + } + + @Override + public void init(int totalParallel, int currentIndex) {} + + @Override + public void fetch(SourceContext ctx) throws Exception { + if (finished) { + return; + } + for (T value : values) { + ctx.collect(value); + } + finished = true; + } + + @Override + public void close() {} +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/Functions.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/Functions.java new file mode 100644 index 00000000..784176fc --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/Functions.java @@ -0,0 +1,40 @@ +package io.ray.streaming.api.function.internal; + +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.function.Function; +import io.ray.streaming.api.function.RichFunction; + +/** A util class for {@link Function} */ +public class Functions { + + private static class DefaultRichFunction implements RichFunction { + + private final Function function; + + private DefaultRichFunction(Function function) { + this.function = function; + } + + @Override + public void open(RuntimeContext runtimeContext) {} + + @Override + public void close() {} + + public Function getFunction() { + return function; + } + } + + public static RichFunction wrap(Function function) { + if (function instanceof RichFunction) { + return (RichFunction) function; + } else { + return new DefaultRichFunction(function); + } + } + + public static RichFunction emptyFunction() { + return new DefaultRichFunction(null); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/Partition.java 
b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/Partition.java new file mode 100644 index 00000000..80e9d927 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/Partition.java @@ -0,0 +1,22 @@ +package io.ray.streaming.api.partition; + +import io.ray.streaming.api.function.Function; + +/** + * Interface of the partitioning strategy. + * + * @param Type of the input data. + */ +@FunctionalInterface +public interface Partition extends Function { + + /** + * Given a record and downstream partitions, determine which partition(s) should receive the + * record. + * + * @param record The record. + * @param numPartition num of partitions + * @return IDs of the downstream partitions that should receive the record. + */ + int[] partition(T record, int numPartition); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/BroadcastPartition.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/BroadcastPartition.java new file mode 100644 index 00000000..c34a7fb2 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/BroadcastPartition.java @@ -0,0 +1,20 @@ +package io.ray.streaming.api.partition.impl; + +import io.ray.streaming.api.partition.Partition; +import java.util.stream.IntStream; + +/** Broadcast the record to all downstream partitions. 
*/ +public class BroadcastPartition implements Partition { + + private int[] partitions = new int[0]; + + public BroadcastPartition() {} + + @Override + public int[] partition(T value, int numPartition) { + if (partitions.length != numPartition) { + partitions = IntStream.rangeClosed(0, numPartition - 1).toArray(); + } + return partitions; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/ForwardPartition.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/ForwardPartition.java new file mode 100644 index 00000000..9b747380 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/ForwardPartition.java @@ -0,0 +1,20 @@ +package io.ray.streaming.api.partition.impl; + +import io.ray.streaming.api.partition.Partition; + +/** + * Default partition for operator if the operator can be chained with succeeding operators. + * Partition will be set to {@link RoundRobinPartition} if the operator can't be chiained with + * succeeding operators. + * + * @param Type of the input record. + */ +public class ForwardPartition implements Partition { + + private int[] partitions = new int[] {0}; + + @Override + public int[] partition(T record, int numPartition) { + return partitions; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/KeyPartition.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/KeyPartition.java new file mode 100644 index 00000000..1de18168 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/KeyPartition.java @@ -0,0 +1,21 @@ +package io.ray.streaming.api.partition.impl; + +import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.message.KeyRecord; + +/** + * Partition the record by the key. + * + * @param Type of the partition key. + * @param Type of the input record. 
+ */ +public class KeyPartition implements Partition> { + + private int[] partitions = new int[1]; + + @Override + public int[] partition(KeyRecord keyRecord, int numPartition) { + partitions[0] = Math.abs(keyRecord.getKey().hashCode() % numPartition); + return partitions; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/RoundRobinPartition.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/RoundRobinPartition.java new file mode 100644 index 00000000..01e624ce --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/partition/impl/RoundRobinPartition.java @@ -0,0 +1,25 @@ +package io.ray.streaming.api.partition.impl; + +import io.ray.streaming.api.partition.Partition; + +/** + * Partition record to downstream tasks in a round-robin matter. + * + * @param Type of the input record. + */ +public class RoundRobinPartition implements Partition { + + private int seq; + private int[] partitions = new int[1]; + + public RoundRobinPartition() { + this.seq = 0; + } + + @Override + public int[] partition(T value, int numPartition) { + seq = (seq + 1) % numPartition; + partitions[0] = seq; + return partitions; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java new file mode 100644 index 00000000..999057d5 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java @@ -0,0 +1,202 @@ +package io.ray.streaming.api.stream; + +import io.ray.streaming.api.Language; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.function.impl.FilterFunction; +import io.ray.streaming.api.function.impl.FlatMapFunction; +import io.ray.streaming.api.function.impl.KeyFunction; +import io.ray.streaming.api.function.impl.MapFunction; +import 
io.ray.streaming.api.function.impl.SinkFunction; +import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.api.partition.impl.BroadcastPartition; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.operator.impl.FilterOperator; +import io.ray.streaming.operator.impl.FlatMapOperator; +import io.ray.streaming.operator.impl.KeyByOperator; +import io.ray.streaming.operator.impl.MapOperator; +import io.ray.streaming.operator.impl.SinkOperator; +import io.ray.streaming.python.stream.PythonDataStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Represents a stream of data. + * + *

This class defines all the streaming operations. + * + * @param Type of data in the stream. + */ +public class DataStream extends Stream, T> { + + public DataStream(StreamingContext streamingContext, StreamOperator streamOperator) { + super(streamingContext, streamOperator); + } + + public DataStream( + StreamingContext streamingContext, StreamOperator streamOperator, Partition partition) { + super(streamingContext, streamOperator, partition); + } + + public DataStream(DataStream input, StreamOperator streamOperator) { + super(input, streamOperator); + } + + public DataStream( + DataStream input, StreamOperator streamOperator, Partition partition) { + super(input, streamOperator, partition); + } + + /** + * Create a java stream that reference passed python stream. Changes in new stream will be + * reflected in referenced stream and vice versa + */ + public DataStream(PythonDataStream referencedStream) { + super(referencedStream); + } + + /** + * Apply a map function to this stream. + * + * @param mapFunction The map function. + * @param Type of data returned by the map function. + * @return A new DataStream. + */ + public DataStream map(MapFunction mapFunction) { + return new DataStream<>(this, new MapOperator<>(mapFunction)); + } + + /** + * Apply a flat-map function to this stream. + * + * @param flatMapFunction The FlatMapFunction + * @param Type of data returned by the flatmap function. + * @return A new DataStream + */ + public DataStream flatMap(FlatMapFunction flatMapFunction) { + return new DataStream<>(this, new FlatMapOperator<>(flatMapFunction)); + } + + public DataStream filter(FilterFunction filterFunction) { + return new DataStream<>(this, new FilterOperator<>(filterFunction)); + } + + /** + * Apply union transformations to this stream by merging {@link DataStream} outputs of the same + * type with each other. + * + * @param stream The DataStream to union output with. + * @param others The other DataStreams to union output with. 
+ * @return A new UnionStream. + */ + @SafeVarargs + public final DataStream union(DataStream stream, DataStream... others) { + List> streams = new ArrayList<>(); + streams.add(stream); + streams.addAll(Arrays.asList(others)); + return union(streams); + } + + /** + * Apply union transformations to this stream by merging {@link DataStream} outputs of the same + * type with each other. + * + * @param streams The DataStreams to union output with. + * @return A new UnionStream. + */ + public final DataStream union(List> streams) { + if (this instanceof UnionStream) { + UnionStream unionStream = (UnionStream) this; + streams.forEach(unionStream::addStream); + return unionStream; + } else { + return new UnionStream<>(this, streams); + } + } + + /** + * Apply a join transformation to this stream, with another stream. + * + * @param other Another stream. + * @param The type of the other stream data. + * @param The type of the data in the joined stream. + * @return A new JoinStream. + */ + public JoinStream join(DataStream other) { + return new JoinStream<>(this, other); + } + + public DataStream process() { + // TODO(zhenxuanpan): Need to add processFunction. + return new DataStream(this, null); + } + + /** + * Apply a sink function and get a StreamSink. + * + * @param sinkFunction The sink function. + * @return A new StreamSink. + */ + public DataStreamSink sink(SinkFunction sinkFunction) { + return new DataStreamSink<>(this, new SinkOperator<>(sinkFunction)); + } + + /** + * Apply a key-by function to this stream. + * + * @param keyFunction the key function. + * @param The type of the key. + * @return A new KeyDataStream. + */ + public KeyDataStream keyBy(KeyFunction keyFunction) { + checkPartitionCall(); + return new KeyDataStream<>(this, new KeyByOperator<>(keyFunction)); + } + + /** + * Apply broadcast to this stream. + * + * @return This stream. 
+ */ + public DataStream broadcast() { + checkPartitionCall(); + return setPartition(new BroadcastPartition<>()); + } + + /** + * Apply a partition to this stream. + * + * @param partition The partitioning strategy. + * @return This stream. + */ + public DataStream partitionBy(Partition partition) { + checkPartitionCall(); + return setPartition(partition); + } + + /** + * If parent stream is a python stream, we can't call partition related methods in the java + * stream. + */ + private void checkPartitionCall() { + if (getInputStream() != null && getInputStream().getLanguage() == Language.PYTHON) { + throw new RuntimeException( + "Partition related methods can't be called on a " + + "java stream if parent stream is a python stream."); + } + } + + /** + * Convert this stream as a python stream. The converted stream and this stream are the same + * logical stream, which has same stream id. Changes in converted stream will be reflected in this + * stream and vice versa. + */ + public PythonDataStream asPythonStream() { + return new PythonDataStream(this); + } + + @Override + public Language getLanguage() { + return Language.JAVA; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSink.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSink.java new file mode 100644 index 00000000..e58bb420 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSink.java @@ -0,0 +1,22 @@ +package io.ray.streaming.api.stream; + +import io.ray.streaming.api.Language; +import io.ray.streaming.operator.impl.SinkOperator; + +/** + * Represents a sink of the DataStream. + * + * @param Type of the input data of this sink. 
+ */ +public class DataStreamSink extends StreamSink { + + public DataStreamSink(DataStream input, SinkOperator sinkOperator) { + super(input, sinkOperator); + getStreamingContext().addSink(this); + } + + @Override + public Language getLanguage() { + return Language.JAVA; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java new file mode 100644 index 00000000..53dd2a09 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java @@ -0,0 +1,37 @@ +package io.ray.streaming.api.stream; + +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.function.impl.SourceFunction; +import io.ray.streaming.api.function.internal.CollectionSourceFunction; +import io.ray.streaming.operator.impl.SourceOperatorImpl; +import java.util.Collection; + +/** + * Represents a source of the DataStream. + * + * @param The type of StreamSource data. + */ +public class DataStreamSource extends DataStream implements StreamSource { + + private DataStreamSource(StreamingContext streamingContext, SourceFunction sourceFunction) { + super(streamingContext, new SourceOperatorImpl<>(sourceFunction)); + } + + public static DataStreamSource fromSource( + StreamingContext context, SourceFunction sourceFunction) { + return new DataStreamSource<>(context, sourceFunction); + } + + /** + * Build a DataStreamSource source from a collection. + * + * @param context Stream context. + * @param values A collection of values. + * @param The type of source data. + * @return A DataStreamSource. 
+ */ + public static DataStreamSource fromCollection( + StreamingContext context, Collection values) { + return new DataStreamSource<>(context, new CollectionSourceFunction<>(values)); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/JoinStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/JoinStream.java new file mode 100644 index 00000000..191c0235 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/JoinStream.java @@ -0,0 +1,80 @@ +package io.ray.streaming.api.stream; + +import io.ray.streaming.api.function.impl.JoinFunction; +import io.ray.streaming.api.function.impl.KeyFunction; +import io.ray.streaming.operator.impl.JoinOperator; +import java.io.Serializable; + +/** + * Represents a DataStream of two joined DataStream. + * + * @param Type of the data in the left stream. + * @param Type of the data in the right stream. + * @param Type of the data in the joined stream. + */ +public class JoinStream extends DataStream { + + private final DataStream rightStream; + + public JoinStream(DataStream leftStream, DataStream rightStream) { + super(leftStream, new JoinOperator<>()); + this.rightStream = rightStream; + } + + public DataStream getRightStream() { + return rightStream; + } + + /** Apply key-by to the left join stream. */ + public Where where(KeyFunction keyFunction) { + return new Where<>(this, keyFunction); + } + + /** + * Where clause of the join transformation. + * + * @param Type of the join key. 
+ */ + class Where implements Serializable { + + private JoinStream joinStream; + private KeyFunction leftKeyByFunction; + + Where(JoinStream joinStream, KeyFunction leftKeyByFunction) { + this.joinStream = joinStream; + this.leftKeyByFunction = leftKeyByFunction; + } + + public Equal equalTo(KeyFunction rightKeyFunction) { + return new Equal<>(joinStream, leftKeyByFunction, rightKeyFunction); + } + } + + /** + * Equal clause of the join transformation. + * + * @param Type of the join key. + */ + class Equal implements Serializable { + + private JoinStream joinStream; + private KeyFunction leftKeyByFunction; + private KeyFunction rightKeyByFunction; + + Equal( + JoinStream joinStream, + KeyFunction leftKeyByFunction, + KeyFunction rightKeyByFunction) { + this.joinStream = joinStream; + this.leftKeyByFunction = leftKeyByFunction; + this.rightKeyByFunction = rightKeyByFunction; + } + + @SuppressWarnings("unchecked") + public DataStream with(JoinFunction joinFunction) { + JoinOperator joinOperator = (JoinOperator) joinStream.getOperator(); + joinOperator.setFunction(joinFunction); + return (DataStream) joinStream; + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/KeyDataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/KeyDataStream.java new file mode 100644 index 00000000..c50b2326 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/KeyDataStream.java @@ -0,0 +1,63 @@ +package io.ray.streaming.api.stream; + +import io.ray.streaming.api.function.impl.AggregateFunction; +import io.ray.streaming.api.function.impl.ReduceFunction; +import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.api.partition.impl.KeyPartition; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.operator.impl.ReduceOperator; +import io.ray.streaming.python.stream.PythonDataStream; +import io.ray.streaming.python.stream.PythonKeyDataStream; 
+ +/** + * Represents a DataStream returned by a key-by operation. + * + * @param Type of the key. + * @param Type of the data. + */ +@SuppressWarnings("unchecked") +public class KeyDataStream extends DataStream { + + public KeyDataStream(DataStream input, StreamOperator streamOperator) { + super(input, streamOperator, (Partition) new KeyPartition()); + } + + /** + * Create a java stream that reference passed python stream. Changes in new stream will be + * reflected in referenced stream and vice versa + */ + public KeyDataStream(PythonDataStream referencedStream) { + super(referencedStream); + } + + /** + * Apply a reduce function to this stream. + * + * @param reduceFunction The reduce function. + * @return A new DataStream. + */ + public DataStream reduce(ReduceFunction reduceFunction) { + return new DataStream<>(this, new ReduceOperator(reduceFunction)); + } + + /** + * Apply an aggregate Function to this stream. + * + * @param aggregateFunction The aggregate function + * @param The type of aggregated intermediate data. + * @param The type of result data. + * @return A new DataStream. + */ + public DataStream aggregate(AggregateFunction aggregateFunction) { + return new DataStream<>(this, null); + } + + /** + * Convert this stream as a python stream. The converted stream and this stream are the same + * logical stream, which has same stream id. Changes in converted stream will be reflected in this + * stream and vice versa. 
+ */ + public PythonKeyDataStream asPythonStream() { + return new PythonKeyDataStream(this); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java new file mode 100644 index 00000000..cd5d538f --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java @@ -0,0 +1,192 @@ +package io.ray.streaming.api.stream; + +import com.google.common.base.Preconditions; +import io.ray.streaming.api.Language; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.api.partition.impl.ForwardPartition; +import io.ray.streaming.operator.ChainStrategy; +import io.ray.streaming.operator.Operator; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.python.PythonPartition; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +/** + * Abstract base class of all stream types. + * + * @param Type of stream class + * @param Type of the data in the stream. 
+ */ +public abstract class Stream, T> implements Serializable { + + private final int id; + private final StreamingContext streamingContext; + private final Stream inputStream; + private final StreamOperator operator; + private int parallelism = 1; + private Map config = new HashMap<>(); + private Partition partition; + private Stream originalStream; + + public Stream(StreamingContext streamingContext, StreamOperator streamOperator) { + this(streamingContext, null, streamOperator, getForwardPartition(streamOperator)); + } + + public Stream( + StreamingContext streamingContext, StreamOperator streamOperator, Partition partition) { + this(streamingContext, null, streamOperator, partition); + } + + public Stream(Stream inputStream, StreamOperator streamOperator) { + this( + inputStream.getStreamingContext(), + inputStream, + streamOperator, + getForwardPartition(streamOperator)); + } + + public Stream(Stream inputStream, StreamOperator streamOperator, Partition partition) { + this(inputStream.getStreamingContext(), inputStream, streamOperator, partition); + } + + protected Stream( + StreamingContext streamingContext, + Stream inputStream, + StreamOperator streamOperator, + Partition partition) { + this.streamingContext = streamingContext; + this.inputStream = inputStream; + this.operator = streamOperator; + this.partition = partition; + this.id = streamingContext.generateId(); + if (inputStream != null) { + this.parallelism = inputStream.getParallelism(); + } + } + + /** + * Create a proxy stream of original stream. 
Changes in new stream will be reflected in original + * stream and vice versa + */ + protected Stream(Stream originalStream) { + this.originalStream = originalStream; + this.id = originalStream.getId(); + this.streamingContext = originalStream.getStreamingContext(); + this.inputStream = originalStream.getInputStream(); + this.operator = originalStream.getOperator(); + Preconditions.checkNotNull(operator); + } + + @SuppressWarnings("unchecked") + private static Partition getForwardPartition(Operator operator) { + switch (operator.getLanguage()) { + case PYTHON: + return (Partition) PythonPartition.ForwardPartition; + case JAVA: + return new ForwardPartition<>(); + default: + throw new UnsupportedOperationException("Unsupported language " + operator.getLanguage()); + } + } + + public int getId() { + return id; + } + + public StreamingContext getStreamingContext() { + return streamingContext; + } + + public Stream getInputStream() { + return inputStream; + } + + public StreamOperator getOperator() { + return operator; + } + + @SuppressWarnings("unchecked") + private S self() { + return (S) this; + } + + public int getParallelism() { + return originalStream != null ? originalStream.getParallelism() : parallelism; + } + + public S setParallelism(int parallelism) { + if (originalStream != null) { + originalStream.setParallelism(parallelism); + } else { + this.parallelism = parallelism; + } + return self(); + } + + @SuppressWarnings("unchecked") + public Partition getPartition() { + return originalStream != null ? 
originalStream.getPartition() : partition; + } + + @SuppressWarnings("unchecked") + protected S setPartition(Partition partition) { + if (originalStream != null) { + originalStream.setPartition(partition); + } else { + this.partition = partition; + } + return self(); + } + + public S withConfig(Map config) { + config.forEach(this::withConfig); + return self(); + } + + public S withConfig(String key, String value) { + if (isProxyStream()) { + originalStream.withConfig(key, value); + } else { + this.config.put(key, value); + } + return self(); + } + + @SuppressWarnings("unchecked") + public Map getConfig() { + return isProxyStream() ? originalStream.getConfig() : config; + } + + public boolean isProxyStream() { + return originalStream != null; + } + + public Stream getOriginalStream() { + Preconditions.checkArgument(isProxyStream()); + return originalStream; + } + + /** Set chain strategy for this stream */ + public S withChainStrategy(ChainStrategy chainStrategy) { + Preconditions.checkArgument(!isProxyStream()); + operator.setChainStrategy(chainStrategy); + return self(); + } + + /** Disable chain for this stream */ + public S disableChain() { + return withChainStrategy(ChainStrategy.NEVER); + } + + /** + * Set the partition function of this {@link Stream} so that output elements are forwarded to next + * operator locally. + */ + public S forward() { + return setPartition(getForwardPartition(operator)); + } + + public abstract Language getLanguage(); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSink.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSink.java new file mode 100644 index 00000000..1dd9fcd4 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSink.java @@ -0,0 +1,15 @@ +package io.ray.streaming.api.stream; + +import io.ray.streaming.operator.StreamOperator; + +/** + * Represents a sink of the Stream. 
+ * + * @param Type of the input data of this sink. + */ +public abstract class StreamSink extends Stream, T> { + + public StreamSink(Stream inputStream, StreamOperator streamOperator) { + super(inputStream, streamOperator); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSource.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSource.java new file mode 100644 index 00000000..0d21b4f5 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSource.java @@ -0,0 +1,8 @@ +package io.ray.streaming.api.stream; + +/** + * A mark interface that represents a source of the Stream. + * + * @param The type of StreamSource data. + */ +public interface StreamSource {} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java new file mode 100644 index 00000000..d2123e1e --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java @@ -0,0 +1,38 @@ +package io.ray.streaming.api.stream; + +import io.ray.streaming.operator.impl.UnionOperator; +import java.util.ArrayList; +import java.util.List; + +/** + * Represents a union DataStream. + * + *

This stream does not create a physical operation, it only affects how upstream data are + * connected to downstream data. + * + * @param The type of union data. + */ +public class UnionStream extends DataStream { + + private List> unionStreams; + + public UnionStream(DataStream input, List> streams) { + // Union stream does not create a physical operation, so we don't have to set partition + // function for it. + super(input, new UnionOperator()); + this.unionStreams = new ArrayList<>(); + streams.forEach(this::addStream); + } + + void addStream(DataStream stream) { + if (stream instanceof UnionStream) { + this.unionStreams.addAll(((UnionStream) stream).getUnionStreams()); + } else { + this.unionStreams.add(stream); + } + } + + public List> getUnionStreams() { + return unionStreams; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/client/JobClient.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/client/JobClient.java new file mode 100644 index 00000000..4a67f8c6 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/client/JobClient.java @@ -0,0 +1,15 @@ +package io.ray.streaming.client; + +import io.ray.streaming.jobgraph.JobGraph; +import java.util.Map; + +/** Interface of the job client. */ +public interface JobClient { + + /** + * Submit job with logical plan to run. + * + * @param jobGraph The logical plan. 
+ */ + void submit(JobGraph jobGraph, Map conf); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobEdge.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobEdge.java new file mode 100644 index 00000000..186ff3b2 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobEdge.java @@ -0,0 +1,54 @@ +package io.ray.streaming.jobgraph; + +import io.ray.streaming.api.partition.Partition; +import java.io.Serializable; + +/** Job edge is connection and partition rules of upstream and downstream execution nodes. */ +public class JobEdge implements Serializable { + + private int srcVertexId; + private int targetVertexId; + private Partition partition; + + public JobEdge(int srcVertexId, int targetVertexId, Partition partition) { + this.srcVertexId = srcVertexId; + this.targetVertexId = targetVertexId; + this.partition = partition; + } + + public int getSrcVertexId() { + return srcVertexId; + } + + public void setSrcVertexId(int srcVertexId) { + this.srcVertexId = srcVertexId; + } + + public int getTargetVertexId() { + return targetVertexId; + } + + public void setTargetVertexId(int targetVertexId) { + this.targetVertexId = targetVertexId; + } + + public Partition getPartition() { + return partition; + } + + public void setPartition(Partition partition) { + this.partition = partition; + } + + @Override + public String toString() { + return "Edge(" + + "from:" + + srcVertexId + + "-" + + targetVertexId + + "-" + + this.partition.getClass() + + ")"; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java new file mode 100644 index 00000000..b192dbcc --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java @@ -0,0 +1,139 @@ +package io.ray.streaming.jobgraph; + +import java.io.Serializable; +import 
java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Job graph, the logical plan of streaming job. */ +public class JobGraph implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(JobGraph.class); + + private final String jobName; + private final Map jobConfig; + private List jobVertices; + private List jobEdges; + private String digraph; + + public JobGraph(String jobName, Map jobConfig) { + this.jobName = jobName; + this.jobConfig = jobConfig; + this.jobVertices = new ArrayList<>(); + this.jobEdges = new ArrayList<>(); + } + + public JobGraph( + String jobName, + Map jobConfig, + List jobVertices, + List jobEdges) { + this.jobName = jobName; + this.jobConfig = jobConfig; + this.jobVertices = jobVertices; + this.jobEdges = jobEdges; + generateDigraph(); + } + + /** + * Generate direct-graph(made up of a set of vertices and connected by edges) by current job graph + * for simple log printing. + * + * @return Digraph in string type. 
+ */ + public String generateDigraph() { + StringBuilder digraph = new StringBuilder(); + digraph.append("digraph ").append(jobName).append(" ").append(" {"); + + for (JobEdge jobEdge : jobEdges) { + String srcNode = null; + String targetNode = null; + for (JobVertex jobVertex : jobVertices) { + if (jobEdge.getSrcVertexId() == jobVertex.getVertexId()) { + srcNode = jobVertex.getVertexId() + "-" + jobVertex.getStreamOperator().getName(); + } else if (jobEdge.getTargetVertexId() == jobVertex.getVertexId()) { + targetNode = jobVertex.getVertexId() + "-" + jobVertex.getStreamOperator().getName(); + } + } + digraph.append(System.getProperty("line.separator")); + digraph.append(String.format(" \"%s\" -> \"%s\"", srcNode, targetNode)); + } + digraph.append(System.getProperty("line.separator")).append("}"); + + this.digraph = digraph.toString(); + return this.digraph; + } + + public void addVertex(JobVertex vertex) { + this.jobVertices.add(vertex); + } + + public void addEdge(JobEdge jobEdge) { + this.jobEdges.add(jobEdge); + } + + public List getJobVertices() { + return jobVertices; + } + + public List getSourceVertices() { + return jobVertices.stream() + .filter(v -> v.getVertexType() == VertexType.SOURCE) + .collect(Collectors.toList()); + } + + public List getSinkVertices() { + return jobVertices.stream() + .filter(v -> v.getVertexType() == VertexType.SINK) + .collect(Collectors.toList()); + } + + public JobVertex getVertex(int vertexId) { + return jobVertices.stream().filter(v -> v.getVertexId() == vertexId).findFirst().get(); + } + + public List getJobEdges() { + return jobEdges; + } + + public Set getVertexInputEdges(int vertexId) { + return jobEdges.stream() + .filter(jobEdge -> jobEdge.getTargetVertexId() == vertexId) + .collect(Collectors.toSet()); + } + + public Set getVertexOutputEdges(int vertexId) { + return jobEdges.stream() + .filter(jobEdge -> jobEdge.getSrcVertexId() == vertexId) + .collect(Collectors.toSet()); + } + + public String getDigraph() { + 
return digraph; + } + + public String getJobName() { + return jobName; + } + + public Map getJobConfig() { + return jobConfig; + } + + public void printJobGraph() { + if (!LOG.isInfoEnabled()) { + return; + } + LOG.info("Printing job graph:"); + for (JobVertex jobVertex : jobVertices) { + LOG.info(jobVertex.toString()); + } + for (JobEdge jobEdge : jobEdges) { + LOG.info(jobEdge.toString()); + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java new file mode 100644 index 00000000..7de388c0 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java @@ -0,0 +1,117 @@ +package io.ray.streaming.jobgraph; + +import com.google.common.base.Preconditions; +import io.ray.streaming.api.stream.DataStream; +import io.ray.streaming.api.stream.JoinStream; +import io.ray.streaming.api.stream.Stream; +import io.ray.streaming.api.stream.StreamSink; +import io.ray.streaming.api.stream.StreamSource; +import io.ray.streaming.api.stream.UnionStream; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.python.stream.PythonDataStream; +import io.ray.streaming.python.stream.PythonUnionStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class JobGraphBuilder { + + private static final Logger LOG = LoggerFactory.getLogger(JobGraphBuilder.class); + + private JobGraph jobGraph; + + private AtomicInteger edgeIdGenerator; + private List streamSinkList; + + public JobGraphBuilder(List streamSinkList) { + this(streamSinkList, "job_" + System.currentTimeMillis()); + } + + public JobGraphBuilder(List streamSinkList, String jobName) { + this(streamSinkList, jobName, new HashMap<>()); + } + + public 
JobGraphBuilder( + List streamSinkList, String jobName, Map jobConfig) { + this.jobGraph = new JobGraph(jobName, jobConfig); + this.streamSinkList = streamSinkList; + this.edgeIdGenerator = new AtomicInteger(0); + } + + public JobGraph build() { + for (StreamSink streamSink : streamSinkList) { + processStream(streamSink); + } + return this.jobGraph; + } + + @SuppressWarnings("unchecked") + private void processStream(Stream stream) { + while (stream.isProxyStream()) { + // Proxy stream and original stream are the same logical stream, both refer to the + // same data flow transformation. We should skip proxy stream to avoid applying same + // transformation multiple times. + LOG.debug("Skip proxy stream {} of id {}", stream, stream.getId()); + stream = stream.getOriginalStream(); + } + StreamOperator streamOperator = stream.getOperator(); + Preconditions.checkArgument( + stream.getLanguage() == streamOperator.getLanguage(), + "Reference stream should be skipped."); + int vertexId = stream.getId(); + int parallelism = stream.getParallelism(); + Map config = stream.getConfig(); + JobVertex jobVertex; + if (stream instanceof StreamSink) { + jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator, config); + Stream parentStream = stream.getInputStream(); + int inputVertexId = parentStream.getId(); + JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition()); + this.jobGraph.addEdge(jobEdge); + processStream(parentStream); + } else if (stream instanceof StreamSource) { + jobVertex = new JobVertex(vertexId, parallelism, VertexType.SOURCE, streamOperator, config); + } else if (stream instanceof DataStream || stream instanceof PythonDataStream) { + jobVertex = + new JobVertex(vertexId, parallelism, VertexType.TRANSFORMATION, streamOperator, config); + Stream parentStream = stream.getInputStream(); + int inputVertexId = parentStream.getId(); + JobEdge jobEdge = new JobEdge(inputVertexId, vertexId, parentStream.getPartition()); 
+ this.jobGraph.addEdge(jobEdge); + processStream(parentStream); + + // process union stream + List streams = new ArrayList<>(); + if (stream instanceof UnionStream) { + streams.addAll(((UnionStream) stream).getUnionStreams()); + } + if (stream instanceof PythonUnionStream) { + streams.addAll(((PythonUnionStream) stream).getUnionStreams()); + } + for (Stream otherStream : streams) { + JobEdge otherEdge = new JobEdge(otherStream.getId(), vertexId, otherStream.getPartition()); + this.jobGraph.addEdge(otherEdge); + processStream(otherStream); + } + + // process join stream + if (stream instanceof JoinStream) { + DataStream rightStream = ((JoinStream) stream).getRightStream(); + this.jobGraph.addEdge( + new JobEdge(rightStream.getId(), vertexId, rightStream.getPartition())); + processStream(rightStream); + } + } else { + throw new UnsupportedOperationException("Unsupported stream: " + stream); + } + this.jobGraph.addVertex(jobVertex); + } + + private int getEdgeId() { + return this.edgeIdGenerator.incrementAndGet(); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphOptimizer.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphOptimizer.java new file mode 100644 index 00000000..7fbb981c --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphOptimizer.java @@ -0,0 +1,204 @@ +package io.ray.streaming.jobgraph; + +import io.ray.streaming.api.Language; +import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.api.partition.impl.ForwardPartition; +import io.ray.streaming.api.partition.impl.RoundRobinPartition; +import io.ray.streaming.operator.ChainStrategy; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.operator.chain.ChainedOperator; +import io.ray.streaming.python.PythonOperator; +import io.ray.streaming.python.PythonOperator.ChainedPythonOperator; +import io.ray.streaming.python.PythonPartition; +import 
java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; + +/** + * Optimize job graph by chaining some operators so that some operators can be run in the same + * thread. + */ +public class JobGraphOptimizer { + + private final JobGraph jobGraph; + private Set visited = new HashSet<>(); + // vertex id -> vertex + private Map vertexMap; + private Map> outputEdgesMap; + // tail vertex id -> mergedVertex + private Map>> mergedVertexMap; + + public JobGraphOptimizer(JobGraph jobGraph) { + this.jobGraph = jobGraph; + vertexMap = + jobGraph.getJobVertices().stream() + .collect(Collectors.toMap(JobVertex::getVertexId, Function.identity())); + outputEdgesMap = + vertexMap.keySet().stream() + .collect( + Collectors.toMap( + id -> vertexMap.get(id), + id -> new HashSet<>(jobGraph.getVertexOutputEdges(id)))); + mergedVertexMap = new HashMap<>(); + } + + public JobGraph optimize() { + // Deep-first traverse nodes from source to sink to merge vertices that can be chained + // together. 
+ jobGraph + .getSourceVertices() + .forEach( + vertex -> { + List verticesToMerge = new ArrayList<>(); + verticesToMerge.add(vertex); + mergeVerticesRecursively(vertex, verticesToMerge); + }); + + List vertices = + mergedVertexMap.values().stream().map(Pair::getLeft).collect(Collectors.toList()); + + return new JobGraph(jobGraph.getJobName(), jobGraph.getJobConfig(), vertices, createEdges()); + } + + private void mergeVerticesRecursively(JobVertex vertex, List verticesToMerge) { + if (!visited.contains(vertex)) { + visited.add(vertex); + Set outputEdges = outputEdgesMap.get(vertex); + if (outputEdges.isEmpty()) { + mergeAndAddVertex(verticesToMerge); + } else { + outputEdges.forEach( + edge -> { + JobVertex succeedingVertex = vertexMap.get(edge.getTargetVertexId()); + if (canBeChained(vertex, succeedingVertex, edge)) { + verticesToMerge.add(succeedingVertex); + mergeVerticesRecursively(succeedingVertex, verticesToMerge); + } else { + mergeAndAddVertex(verticesToMerge); + List newMergedVertices = new ArrayList<>(); + newMergedVertices.add(succeedingVertex); + mergeVerticesRecursively(succeedingVertex, newMergedVertices); + } + }); + } + } + } + + private void mergeAndAddVertex(List verticesToMerge) { + JobVertex mergedVertex; + JobVertex headVertex = verticesToMerge.get(0); + Language language = headVertex.getLanguage(); + if (verticesToMerge.size() == 1) { + // no chain + mergedVertex = headVertex; + } else { + List operators = + verticesToMerge.stream() + .map(v -> vertexMap.get(v.getVertexId()).getStreamOperator()) + .collect(Collectors.toList()); + List> configs = + verticesToMerge.stream() + .map(v -> vertexMap.get(v.getVertexId()).getConfig()) + .collect(Collectors.toList()); + StreamOperator operator; + if (language == Language.JAVA) { + operator = ChainedOperator.newChainedOperator(operators, configs); + } else { + List pythonOperators = + operators.stream().map(o -> (PythonOperator) o).collect(Collectors.toList()); + operator = new 
ChainedPythonOperator(pythonOperators, configs); + } + // chained operator config is placed into `ChainedOperator`. + mergedVertex = + new JobVertex( + headVertex.getVertexId(), + headVertex.getParallelism(), + headVertex.getVertexType(), + operator, + new HashMap<>()); + } + + mergedVertexMap.put(mergedVertex.getVertexId(), Pair.of(mergedVertex, verticesToMerge)); + } + + private List createEdges() { + List edges = new ArrayList<>(); + mergedVertexMap.forEach( + (id, pair) -> { + JobVertex mergedVertex = pair.getLeft(); + List mergedVertices = pair.getRight(); + JobVertex tailVertex = mergedVertices.get(mergedVertices.size() - 1); + // input edge will be set up in input vertices + if (outputEdgesMap.containsKey(tailVertex)) { + outputEdgesMap + .get(tailVertex) + .forEach( + edge -> { + Pair> downstreamPair = + mergedVertexMap.get(edge.getTargetVertexId()); + // change ForwardPartition to RoundRobinPartition. + Partition partition = changePartition(edge.getPartition()); + JobEdge newEdge = + new JobEdge( + mergedVertex.getVertexId(), + downstreamPair.getLeft().getVertexId(), + partition); + edges.add(newEdge); + }); + } + }); + return edges; + } + + /** Change ForwardPartition to RoundRobinPartition. 
*/ + private Partition changePartition(Partition partition) { + if (partition instanceof PythonPartition) { + PythonPartition pythonPartition = (PythonPartition) partition; + if (!pythonPartition.isConstructedFromBinary() + && pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS)) { + return PythonPartition.RoundRobinPartition; + } else { + return partition; + } + } else { + if (partition instanceof ForwardPartition) { + return new RoundRobinPartition(); + } else { + return partition; + } + } + } + + private boolean canBeChained( + JobVertex precedingVertex, JobVertex succeedingVertex, JobEdge edge) { + if (jobGraph.getVertexOutputEdges(precedingVertex.getVertexId()).size() > 1 + || jobGraph.getVertexInputEdges(succeedingVertex.getVertexId()).size() > 1) { + return false; + } + if (precedingVertex.getParallelism() != succeedingVertex.getParallelism()) { + return false; + } + if (precedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER + || succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER + || succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.HEAD) { + return false; + } + if (precedingVertex.getLanguage() != succeedingVertex.getLanguage()) { + return false; + } + Partition partition = edge.getPartition(); + if (!(partition instanceof PythonPartition)) { + return partition instanceof ForwardPartition; + } else { + PythonPartition pythonPartition = (PythonPartition) partition; + return !pythonPartition.isConstructedFromBinary() + && pythonPartition.getFunctionName().equals(PythonPartition.FORWARD_PARTITION_CLASS); + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobVertex.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobVertex.java new file mode 100644 index 00000000..a8083246 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobVertex.java @@ -0,0 +1,72 @@ 
+package io.ray.streaming.jobgraph; + +import com.google.common.base.MoreObjects; +import io.ray.streaming.api.Language; +import io.ray.streaming.operator.StreamOperator; +import java.io.Serializable; +import java.util.Map; + +/** Job vertex is a cell node where logic is executed. */ +public class JobVertex implements Serializable { + + private int vertexId; + private int parallelism; + private VertexType vertexType; + private Language language; + private StreamOperator streamOperator; + private Map config; + + public JobVertex( + int vertexId, + int parallelism, + VertexType vertexType, + StreamOperator streamOperator, + Map config) { + this.vertexId = vertexId; + this.parallelism = parallelism; + this.vertexType = vertexType; + this.streamOperator = streamOperator; + this.language = streamOperator.getLanguage(); + this.config = config; + } + + public int getVertexId() { + return vertexId; + } + + public int getParallelism() { + return parallelism; + } + + public StreamOperator getStreamOperator() { + return streamOperator; + } + + public VertexType getVertexType() { + return vertexType; + } + + public Language getLanguage() { + return language; + } + + public Map getConfig() { + return config; + } + + public void setConfig(Map config) { + this.config = config; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("vertexId", vertexId) + .add("parallelism", parallelism) + .add("vertexType", vertexType) + .add("language", language) + .add("streamOperator", streamOperator) + .add("config", config) + .toString(); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/VertexType.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/VertexType.java new file mode 100644 index 00000000..90dc3634 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/VertexType.java @@ -0,0 +1,8 @@ +package io.ray.streaming.jobgraph; + +/** Different roles 
for a node. */ +public enum VertexType { + SOURCE, + TRANSFORMATION, + SINK, +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java new file mode 100644 index 00000000..4eb65568 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java @@ -0,0 +1,41 @@ +package io.ray.streaming.message; + +import java.util.Objects; + +public class KeyRecord extends Record { + + private K key; + + public KeyRecord(K key, T value) { + super(value); + this.key = key; + } + + public K getKey() { + return key; + } + + public void setKey(K key) { + this.key = key; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + KeyRecord keyRecord = (KeyRecord) o; + return Objects.equals(key, keyRecord.key); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), key); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Record.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Record.java new file mode 100644 index 00000000..6f209320 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Record.java @@ -0,0 +1,52 @@ +package io.ray.streaming.message; + +import java.io.Serializable; +import java.util.Objects; + +public class Record implements Serializable { + + protected transient String stream; + protected T value; + + public Record(T value) { + this.value = value; + } + + public T getValue() { + return value; + } + + public void setValue(T value) { + this.value = value; + } + + public String getStream() { + return stream; + } + + public void setStream(String stream) { + this.stream = stream; + } + + @Override + public boolean equals(Object o) { + if (this 
== o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Record record = (Record) o; + return Objects.equals(stream, record.stream) && Objects.equals(value, record.value); + } + + @Override + public int hashCode() { + return Objects.hash(stream, value); + } + + @Override + public String toString() { + return value.toString(); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/ChainStrategy.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/ChainStrategy.java new file mode 100644 index 00000000..275ebd8f --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/ChainStrategy.java @@ -0,0 +1,14 @@ +package io.ray.streaming.operator; + +/** Chain strategy for streaming operators. Chained operators are run in the same thread. */ +public enum ChainStrategy { + /** + * The operator won't be chained with preceding operators, but maybe chained with succeeding + * operators. + */ + HEAD, + /** Operators will be chained together when possible. */ + ALWAYS, + /** The operator won't be chained with any operator. 
*/ + NEVER +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/OneInputOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/OneInputOperator.java new file mode 100644 index 00000000..c0f6c4df --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/OneInputOperator.java @@ -0,0 +1,12 @@ +package io.ray.streaming.operator; + +import io.ray.streaming.message.Record; + +public interface OneInputOperator extends Operator { + + void processElement(Record record) throws Exception; + + default OperatorType getOpType() { + return OperatorType.ONE_INPUT; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java new file mode 100644 index 00000000..3754385e --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java @@ -0,0 +1,33 @@ +package io.ray.streaming.operator; + +import io.ray.streaming.api.Language; +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.function.Function; +import java.io.Serializable; +import java.util.List; + +public interface Operator extends Serializable { + + String getName(); + + void open(List collectors, RuntimeContext runtimeContext); + + void finish(); + + void close(); + + Function getFunction(); + + Language getLanguage(); + + OperatorType getOpType(); + + ChainStrategy getChainStrategy(); + + /** See {@link Function#saveCheckpoint()}. */ + Serializable saveCheckpoint(); + + /** See {@link Function#loadCheckpoint(Serializable)}. 
*/ + void loadCheckpoint(Serializable checkpointObject); +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/OperatorType.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/OperatorType.java new file mode 100644 index 00000000..68ab987b --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/OperatorType.java @@ -0,0 +1,7 @@ +package io.ray.streaming.operator; + +public enum OperatorType { + SOURCE, + ONE_INPUT, + TWO_INPUT, +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java new file mode 100644 index 00000000..04ba1ded --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java @@ -0,0 +1,14 @@ +package io.ray.streaming.operator; + +import io.ray.streaming.api.function.impl.SourceFunction.SourceContext; + +public interface SourceOperator extends Operator { + + void fetch(); + + SourceContext getSourceContext(); + + default OperatorType getOpType() { + return OperatorType.SOURCE; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java new file mode 100644 index 00000000..2c7980ab --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java @@ -0,0 +1,97 @@ +package io.ray.streaming.operator; + +import io.ray.streaming.api.Language; +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.function.Function; +import io.ray.streaming.api.function.RichFunction; +import io.ray.streaming.api.function.internal.Functions; +import io.ray.streaming.message.KeyRecord; +import io.ray.streaming.message.Record; +import 
java.io.Serializable; +import java.util.List; + +public abstract class StreamOperator implements Operator { + + protected final String name; + protected F function; + protected RichFunction richFunction; + protected List collectorList; + protected RuntimeContext runtimeContext; + private ChainStrategy chainStrategy = ChainStrategy.ALWAYS; + + protected StreamOperator() { + this.name = getClass().getSimpleName(); + } + + protected StreamOperator(F function) { + this(); + setFunction(function); + } + + public void setFunction(F function) { + this.function = function; + this.richFunction = Functions.wrap(function); + } + + @Override + public void open(List collectorList, RuntimeContext runtimeContext) { + this.collectorList = collectorList; + this.runtimeContext = runtimeContext; + richFunction.open(runtimeContext); + } + + @Override + public void finish() {} + + @Override + public void close() { + richFunction.close(); + } + + @Override + public Function getFunction() { + return function; + } + + @Override + public Language getLanguage() { + return Language.JAVA; + } + + protected void collect(Record record) { + for (Collector collector : this.collectorList) { + collector.collect(record); + } + } + + protected void collect(KeyRecord keyRecord) { + for (Collector collector : this.collectorList) { + collector.collect(keyRecord); + } + } + + @Override + public Serializable saveCheckpoint() { + return function.saveCheckpoint(); + } + + @Override + public void loadCheckpoint(Serializable checkpointObject) { + function.loadCheckpoint(checkpointObject); + } + + @Override + public String getName() { + return name; + } + + public void setChainStrategy(ChainStrategy chainStrategy) { + this.chainStrategy = chainStrategy; + } + + @Override + public ChainStrategy getChainStrategy() { + return chainStrategy; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/TwoInputOperator.java 
b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/TwoInputOperator.java new file mode 100644 index 00000000..54e95984 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/TwoInputOperator.java @@ -0,0 +1,12 @@ +package io.ray.streaming.operator; + +import io.ray.streaming.message.Record; + +public interface TwoInputOperator extends Operator { + + void processElement(Record record1, Record record2); + + default OperatorType getOpType() { + return OperatorType.TWO_INPUT; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java new file mode 100644 index 00000000..62aa3e3b --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java @@ -0,0 +1,187 @@ +package io.ray.streaming.operator.chain; + +import com.google.common.base.Preconditions; +import io.ray.streaming.api.Language; +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.function.Function; +import io.ray.streaming.api.function.impl.SourceFunction.SourceContext; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; +import io.ray.streaming.operator.Operator; +import io.ray.streaming.operator.OperatorType; +import io.ray.streaming.operator.SourceOperator; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.operator.TwoInputOperator; +import java.io.Serializable; +import java.lang.reflect.Proxy; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** Abstract base class for chained operators. 
*/ +public abstract class ChainedOperator extends StreamOperator { + + protected final List operators; + protected final Operator headOperator; + protected final Operator tailOperator; + private final List> configs; + + public ChainedOperator(List operators, List> configs) { + Preconditions.checkArgument( + operators.size() >= 2, "Need at lease two operators to be chained together"); + operators.stream() + .skip(1) + .forEach(operator -> Preconditions.checkArgument(operator instanceof OneInputOperator)); + this.operators = operators; + this.configs = configs; + this.headOperator = operators.get(0); + this.tailOperator = operators.get(operators.size() - 1); + } + + @Override + public void open(List collectorList, RuntimeContext runtimeContext) { + // Dont' call super.open() as we `open` every operator separately. + List succeedingCollectors = + operators.stream() + .skip(1) + .map(operator -> new ForwardCollector((OneInputOperator) operator)) + .collect(Collectors.toList()); + for (int i = 0; i < operators.size() - 1; i++) { + StreamOperator operator = operators.get(i); + List forwardCollectors = + Collections.singletonList(succeedingCollectors.get(i)); + operator.open(forwardCollectors, createRuntimeContext(runtimeContext, i)); + } + // tail operator send data to downstream using provided collectors. 
+ tailOperator.open(collectorList, createRuntimeContext(runtimeContext, operators.size() - 1)); + } + + @Override + public OperatorType getOpType() { + return headOperator.getOpType(); + } + + @Override + public Language getLanguage() { + return headOperator.getLanguage(); + } + + @Override + public String getName() { + return operators.stream().map(Operator::getName).collect(Collectors.joining(" -> ", "[", "]")); + } + + public List getOperators() { + return operators; + } + + public Operator getHeadOperator() { + return headOperator; + } + + public Operator getTailOperator() { + return tailOperator; + } + + @Override + public Serializable saveCheckpoint() { + Serializable[] checkpoints = new Serializable[operators.size()]; + for (int i = 0; i < operators.size(); ++i) { + checkpoints[i] = operators.get(i).saveCheckpoint(); + } + return checkpoints; + } + + @Override + public void loadCheckpoint(Serializable checkpointObject) { + Serializable[] checkpoints = (Serializable[]) checkpointObject; + for (int i = 0; i < operators.size(); ++i) { + operators.get(i).loadCheckpoint(checkpoints[i]); + } + } + + private RuntimeContext createRuntimeContext(RuntimeContext runtimeContext, int index) { + return (RuntimeContext) + Proxy.newProxyInstance( + runtimeContext.getClass().getClassLoader(), + new Class[] {RuntimeContext.class}, + (proxy, method, methodArgs) -> { + if (method.getName().equals("getConfig")) { + return configs.get(index); + } else { + return method.invoke(runtimeContext, methodArgs); + } + }); + } + + public static ChainedOperator newChainedOperator( + List operators, List> configs) { + switch (operators.get(0).getOpType()) { + case SOURCE: + return new ChainedSourceOperator(operators, configs); + case ONE_INPUT: + return new ChainedOneInputOperator(operators, configs); + case TWO_INPUT: + return new ChainedTwoInputOperator(operators, configs); + default: + throw new IllegalArgumentException( + "Unsupported operator type " + operators.get(0).getOpType()); + } 
+ } + + static class ChainedSourceOperator extends ChainedOperator implements SourceOperator { + + private final SourceOperator sourceOperator; + + @SuppressWarnings("unchecked") + ChainedSourceOperator(List operators, List> configs) { + super(operators, configs); + sourceOperator = (SourceOperator) headOperator; + } + + @Override + public void fetch() { + sourceOperator.fetch(); + } + + @Override + public SourceContext getSourceContext() { + return sourceOperator.getSourceContext(); + } + } + + static class ChainedOneInputOperator extends ChainedOperator implements OneInputOperator { + + private final OneInputOperator inputOperator; + + @SuppressWarnings("unchecked") + ChainedOneInputOperator(List operators, List> configs) { + super(operators, configs); + inputOperator = (OneInputOperator) headOperator; + } + + @Override + public void processElement(Record record) throws Exception { + inputOperator.processElement(record); + } + } + + static class ChainedTwoInputOperator extends ChainedOperator + implements TwoInputOperator { + + private final TwoInputOperator inputOperator; + + @SuppressWarnings("unchecked") + ChainedTwoInputOperator(List operators, List> configs) { + super(operators, configs); + inputOperator = (TwoInputOperator) headOperator; + } + + @Override + public void processElement(Record record1, Record record2) { + inputOperator.processElement(record1, record2); + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ForwardCollector.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ForwardCollector.java new file mode 100644 index 00000000..36680dae --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ForwardCollector.java @@ -0,0 +1,24 @@ +package io.ray.streaming.operator.chain; + +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; + +class 
ForwardCollector implements Collector { + + private final OneInputOperator succeedingOperator; + + ForwardCollector(OneInputOperator succeedingOperator) { + this.succeedingOperator = succeedingOperator; + } + + @SuppressWarnings("unchecked") + @Override + public void collect(Record record) { + try { + succeedingOperator.processElement(record); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/FilterOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/FilterOperator.java new file mode 100644 index 00000000..f4be1b50 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/FilterOperator.java @@ -0,0 +1,21 @@ +package io.ray.streaming.operator.impl; + +import io.ray.streaming.api.function.impl.FilterFunction; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; +import io.ray.streaming.operator.StreamOperator; + +public class FilterOperator extends StreamOperator> + implements OneInputOperator { + + public FilterOperator(FilterFunction filterFunction) { + super(filterFunction); + } + + @Override + public void processElement(Record record) throws Exception { + if (this.function.filter(record.getValue())) { + this.collect(record); + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/FlatMapOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/FlatMapOperator.java new file mode 100644 index 00000000..0fa6e818 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/FlatMapOperator.java @@ -0,0 +1,31 @@ +package io.ray.streaming.operator.impl; + +import io.ray.streaming.api.collector.CollectionCollector; +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.context.RuntimeContext; +import 
io.ray.streaming.api.function.impl.FlatMapFunction; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; +import io.ray.streaming.operator.StreamOperator; +import java.util.List; + +public class FlatMapOperator extends StreamOperator> + implements OneInputOperator { + + private CollectionCollector collectionCollector; + + public FlatMapOperator(FlatMapFunction flatMapFunction) { + super(flatMapFunction); + } + + @Override + public void open(List collectorList, RuntimeContext runtimeContext) { + super.open(collectorList, runtimeContext); + this.collectionCollector = new CollectionCollector(collectorList); + } + + @Override + public void processElement(Record record) throws Exception { + this.function.flatMap(record.getValue(), (Collector) collectionCollector); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/JoinOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/JoinOperator.java new file mode 100644 index 00000000..d5cb600e --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/JoinOperator.java @@ -0,0 +1,35 @@ +package io.ray.streaming.operator.impl; + +import io.ray.streaming.api.function.impl.JoinFunction; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.ChainStrategy; +import io.ray.streaming.operator.OperatorType; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.operator.TwoInputOperator; + +/** + * Join operator + * + * @param Type of the data in the left stream. + * @param Type of the data in the right stream. + * @param Type of the data in the join key. + * @param Type of the data in the joined stream. 
+ */ +public class JoinOperator extends StreamOperator> + implements TwoInputOperator { + + public JoinOperator() {} + + public JoinOperator(JoinFunction function) { + super(function); + setChainStrategy(ChainStrategy.HEAD); + } + + @Override + public void processElement(Record record1, Record record2) {} + + @Override + public OperatorType getOpType() { + return OperatorType.TWO_INPUT; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/KeyByOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/KeyByOperator.java new file mode 100644 index 00000000..0e182833 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/KeyByOperator.java @@ -0,0 +1,21 @@ +package io.ray.streaming.operator.impl; + +import io.ray.streaming.api.function.impl.KeyFunction; +import io.ray.streaming.message.KeyRecord; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; +import io.ray.streaming.operator.StreamOperator; + +public class KeyByOperator extends StreamOperator> + implements OneInputOperator { + + public KeyByOperator(KeyFunction keyFunction) { + super(keyFunction); + } + + @Override + public void processElement(Record record) throws Exception { + K key = this.function.keyBy(record.getValue()); + collect(new KeyRecord<>(key, record.getValue())); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/MapOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/MapOperator.java new file mode 100644 index 00000000..7d6a7e3e --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/MapOperator.java @@ -0,0 +1,19 @@ +package io.ray.streaming.operator.impl; + +import io.ray.streaming.api.function.impl.MapFunction; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; +import 
io.ray.streaming.operator.StreamOperator; + +public class MapOperator extends StreamOperator> + implements OneInputOperator { + + public MapOperator(MapFunction mapFunction) { + super(mapFunction); + } + + @Override + public void processElement(Record record) throws Exception { + this.collect(new Record(this.function.map(record.getValue()))); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/ReduceOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/ReduceOperator.java new file mode 100644 index 00000000..8ec8a637 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/ReduceOperator.java @@ -0,0 +1,46 @@ +package io.ray.streaming.operator.impl; + +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.function.impl.ReduceFunction; +import io.ray.streaming.message.KeyRecord; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.ChainStrategy; +import io.ray.streaming.operator.OneInputOperator; +import io.ray.streaming.operator.StreamOperator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class ReduceOperator extends StreamOperator> + implements OneInputOperator { + + private Map reduceState; + + public ReduceOperator(ReduceFunction reduceFunction) { + super(reduceFunction); + setChainStrategy(ChainStrategy.HEAD); + } + + @Override + public void open(List collectorList, RuntimeContext runtimeContext) { + super.open(collectorList, runtimeContext); + this.reduceState = new HashMap<>(); + } + + @Override + public void processElement(Record record) throws Exception { + KeyRecord keyRecord = (KeyRecord) record; + K key = keyRecord.getKey(); + T value = keyRecord.getValue(); + if (reduceState.containsKey(key)) { + T oldValue = reduceState.get(key); + T newValue = this.function.reduce(oldValue, value); + 
reduceState.put(key, newValue); + collect(new Record(newValue)); + } else { + reduceState.put(key, value); + collect(record); + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SinkOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SinkOperator.java new file mode 100644 index 00000000..40808e8e --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SinkOperator.java @@ -0,0 +1,19 @@ +package io.ray.streaming.operator.impl; + +import io.ray.streaming.api.function.impl.SinkFunction; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; +import io.ray.streaming.operator.StreamOperator; + +public class SinkOperator extends StreamOperator> + implements OneInputOperator { + + public SinkOperator(SinkFunction sinkFunction) { + super(sinkFunction); + } + + @Override + public void processElement(Record record) throws Exception { + this.function.sink(record.getValue()); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java new file mode 100644 index 00000000..354ada69 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java @@ -0,0 +1,65 @@ +package io.ray.streaming.operator.impl; + +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.function.impl.SourceFunction; +import io.ray.streaming.api.function.impl.SourceFunction.SourceContext; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.ChainStrategy; +import io.ray.streaming.operator.OperatorType; +import io.ray.streaming.operator.SourceOperator; +import io.ray.streaming.operator.StreamOperator; +import java.util.List; + +public class 
SourceOperatorImpl extends StreamOperator> + implements SourceOperator { + + private SourceContextImpl sourceContext; + + public SourceOperatorImpl(SourceFunction function) { + super(function); + setChainStrategy(ChainStrategy.HEAD); + } + + @Override + public void open(List collectorList, RuntimeContext runtimeContext) { + super.open(collectorList, runtimeContext); + this.sourceContext = new SourceContextImpl(collectorList); + this.function.init(runtimeContext.getParallelism(), runtimeContext.getTaskIndex()); + } + + @Override + public void fetch() { + try { + this.function.fetch(this.sourceContext); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public SourceContext getSourceContext() { + return sourceContext; + } + + @Override + public OperatorType getOpType() { + return OperatorType.SOURCE; + } + + class SourceContextImpl implements SourceContext { + + private List collectors; + + public SourceContextImpl(List collectors) { + this.collectors = collectors; + } + + @Override + public void collect(T t) throws Exception { + for (Collector collector : collectors) { + collector.collect(new Record<>(t)); + } + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/UnionOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/UnionOperator.java new file mode 100644 index 00000000..2b916e80 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/UnionOperator.java @@ -0,0 +1,19 @@ +package io.ray.streaming.operator.impl; + +import io.ray.streaming.api.function.Function; +import io.ray.streaming.api.function.internal.Functions; +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; +import io.ray.streaming.operator.StreamOperator; + +public class UnionOperator extends StreamOperator implements OneInputOperator { + + public UnionOperator() { + super(Functions.emptyFunction()); + } + + 
@Override + public void processElement(Record record) { + collect(record); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java new file mode 100644 index 00000000..87ef76d2 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java @@ -0,0 +1,124 @@ +package io.ray.streaming.python; + +import com.google.common.base.Preconditions; +import io.ray.streaming.api.function.Function; +import java.util.StringJoiner; +import org.apache.commons.lang3.StringUtils; + +/** + * Represents a user defined python function. + * + *

Python worker can use information in this class to create a function object. + * + *

If this object is constructed from serialized python function, python worker can deserialize + * it to create python function directly. If this object is constructed from moduleName and + * className/functionName, python worker will use `importlib` to load python function. + * + *

If the python data stream api is invoked from python, `function` will be not null. + * + *

If the python data stream api is invoked from java, `moduleName` and `functionName` will be + * not null. + * + *

+ */ +public class PythonFunction implements Function { + + public enum FunctionInterface { + SOURCE_FUNCTION("SourceFunction"), + MAP_FUNCTION("MapFunction"), + FLAT_MAP_FUNCTION("FlatMapFunction"), + FILTER_FUNCTION("FilterFunction"), + KEY_FUNCTION("KeyFunction"), + REDUCE_FUNCTION("ReduceFunction"), + SINK_FUNCTION("SinkFunction"); + + private String functionInterface; + + /** @param functionInterface function class name in `ray.streaming.function` module. */ + FunctionInterface(String functionInterface) { + this.functionInterface = functionInterface; + } + } + + // null if this function is constructed from moduleName/functionName. + private final byte[] function; + // null if this function is constructed from serialized python function. + private final String moduleName; + // null if this function is constructed from serialized python function. + private final String functionName; + /** + * FunctionInterface can be used to validate python function, and look up operator class from + * FunctionInterface. + */ + private String functionInterface; + + /** + * Create a {@link PythonFunction} from a serialized streaming python function. + * + * @param function serialized streaming python function from python driver. + */ + public PythonFunction(byte[] function) { + Preconditions.checkNotNull(function); + this.function = function; + this.moduleName = null; + this.functionName = null; + } + + /** + * Create a {@link PythonFunction} from a moduleName and streaming function name. + * + * @param moduleName module name of streaming function. + * @param functionName function name of streaming function. 
{@code functionName} is the name of a + * python function, or class name of subclass of `ray.streaming.function.` + */ + public PythonFunction(String moduleName, String functionName) { + Preconditions.checkArgument(StringUtils.isNotBlank(moduleName)); + Preconditions.checkArgument(StringUtils.isNotBlank(functionName)); + this.function = null; + this.moduleName = moduleName; + this.functionName = functionName; + } + + public void setFunctionInterface(FunctionInterface functionInterface) { + this.functionInterface = functionInterface.functionInterface; + } + + public byte[] getFunction() { + return function; + } + + public String getModuleName() { + return moduleName; + } + + public String getFunctionName() { + return functionName; + } + + public String getFunctionInterface() { + return functionInterface; + } + + public String toSimpleString() { + if (function != null) { + return "binary function"; + } else { + return String.format("%s-%s.%s", functionInterface, moduleName, functionName); + } + } + + @Override + public String toString() { + StringJoiner stringJoiner = + new StringJoiner(", ", PythonFunction.class.getSimpleName() + "[", "]"); + if (function != null) { + stringJoiner.add("function=binary function"); + } else { + stringJoiner + .add("moduleName='" + moduleName + "'") + .add("functionName='" + functionName + "'"); + } + stringJoiner.add("functionInterface='" + functionInterface + "'"); + return stringJoiner.toString(); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonOperator.java new file mode 100644 index 00000000..729803e1 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonOperator.java @@ -0,0 +1,171 @@ +package io.ray.streaming.python; + +import com.google.common.base.Preconditions; +import io.ray.streaming.api.Language; +import io.ray.streaming.api.context.RuntimeContext; 
+import io.ray.streaming.api.function.Function; +import io.ray.streaming.operator.Operator; +import io.ray.streaming.operator.OperatorType; +import io.ray.streaming.operator.StreamOperator; +import java.util.List; +import java.util.Map; +import java.util.StringJoiner; +import java.util.stream.Collectors; + +/** Represents a {@link StreamOperator} that wraps python {@link PythonFunction}. */ +@SuppressWarnings("unchecked") +public class PythonOperator extends StreamOperator { + + private final String moduleName; + private final String className; + + public PythonOperator(String moduleName, String className) { + super(null); + this.moduleName = moduleName; + this.className = className; + } + + public PythonOperator(PythonFunction function) { + super(function); + this.moduleName = null; + this.className = null; + } + + @Override + public Language getLanguage() { + return Language.PYTHON; + } + + public String getModuleName() { + return moduleName; + } + + public String getClassName() { + return className; + } + + @Override + public void open(List list, RuntimeContext runtimeContext) { + throwUnsupportedException(); + } + + @Override + public void finish() { + throwUnsupportedException(); + } + + @Override + public void close() { + throwUnsupportedException(); + } + + void throwUnsupportedException() { + StackTraceElement[] trace = Thread.currentThread().getStackTrace(); + Preconditions.checkState(trace.length >= 2); + StackTraceElement traceElement = trace[2]; + String msg = + String.format( + "Method %s.%s shouldn't be called.", + traceElement.getClassName(), traceElement.getMethodName()); + throw new UnsupportedOperationException(msg); + } + + @Override + public OperatorType getOpType() { + String msg = String.format("Methods of %s shouldn't be called.", getClass().getSimpleName()); + throw new UnsupportedOperationException(msg); + } + + @Override + public String getName() { + StringBuilder builder = new StringBuilder(); + 
builder.append(PythonOperator.class.getSimpleName()).append("["); + if (function != null) { + builder.append(((PythonFunction) function).toSimpleString()); + } else { + builder.append(moduleName).append(".").append(className); + } + return builder.append("]").toString(); + } + + @Override + public String toString() { + StringJoiner stringJoiner = + new StringJoiner(", ", PythonOperator.class.getSimpleName() + "[", "]"); + if (function != null) { + stringJoiner.add("function='" + function + "'"); + } else { + stringJoiner.add("moduleName='" + moduleName + "'").add("className='" + className + "'"); + } + return stringJoiner.toString(); + } + + public static class ChainedPythonOperator extends PythonOperator { + + private final List operators; + private final PythonOperator headOperator; + private final PythonOperator tailOperator; + private final List> configs; + + public ChainedPythonOperator( + List operators, List> configs) { + super(null); + Preconditions.checkArgument(!operators.isEmpty()); + this.operators = operators; + this.configs = configs; + this.headOperator = operators.get(0); + this.tailOperator = operators.get(operators.size() - 1); + } + + @Override + public OperatorType getOpType() { + return headOperator.getOpType(); + } + + @Override + public Language getLanguage() { + return Language.PYTHON; + } + + @Override + public String getName() { + return operators.stream() + .map(Operator::getName) + .collect(Collectors.joining(" -> ", "[", "]")); + } + + @Override + public String getModuleName() { + throwUnsupportedException(); + return null; // impossible + } + + @Override + public String getClassName() { + throwUnsupportedException(); + return null; // impossible + } + + @Override + public Function getFunction() { + throwUnsupportedException(); + return null; // impossible + } + + public List getOperators() { + return operators; + } + + public PythonOperator getHeadOperator() { + return headOperator; + } + + public PythonOperator getTailOperator() { + 
return tailOperator; + } + + public List> getConfigs() { + return configs; + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java new file mode 100644 index 00000000..e6d80836 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java @@ -0,0 +1,89 @@ +package io.ray.streaming.python; + +import com.google.common.base.Preconditions; +import io.ray.streaming.api.partition.Partition; +import java.util.StringJoiner; +import org.apache.commons.lang3.StringUtils; + +/** + * Represents a python partition function. + * + *

Python worker can create a partition object using information in this PythonPartition. + * + *

If this object is constructed from serialized python partition, python worker can deserialize + * it to create python partition directly. If this object is constructed from moduleName and + * className/functionName, python worker will use `importlib` to load python partition function. + * + *

+ */ +public class PythonPartition implements Partition { + + public static final PythonPartition BroadcastPartition = + new PythonPartition("ray.streaming.partition", "BroadcastPartition"); + public static final PythonPartition KeyPartition = + new PythonPartition("ray.streaming.partition", "KeyPartition"); + public static final PythonPartition RoundRobinPartition = + new PythonPartition("ray.streaming.partition", "RoundRobinPartition"); + public static final String FORWARD_PARTITION_CLASS = "ForwardPartition"; + public static final PythonPartition ForwardPartition = + new PythonPartition("ray.streaming.partition", FORWARD_PARTITION_CLASS); + + private byte[] partition; + private String moduleName; + private String functionName; + + public PythonPartition(byte[] partition) { + Preconditions.checkNotNull(partition); + this.partition = partition; + } + + /** + * Create a python partition from a moduleName and partition function name + * + * @param moduleName module name of python partition + * @param functionName function/class name of the partition function. 
+ */ + public PythonPartition(String moduleName, String functionName) { + Preconditions.checkArgument(StringUtils.isNotBlank(moduleName)); + Preconditions.checkArgument(StringUtils.isNotBlank(functionName)); + this.moduleName = moduleName; + this.functionName = functionName; + } + + @Override + public int[] partition(Object record, int numPartition) { + String msg = + String.format("partition method of %s shouldn't be called.", getClass().getSimpleName()); + throw new UnsupportedOperationException(msg); + } + + public byte[] getPartition() { + return partition; + } + + public String getModuleName() { + return moduleName; + } + + public String getFunctionName() { + return functionName; + } + + public boolean isConstructedFromBinary() { + return partition != null; + } + + @Override + public String toString() { + StringJoiner stringJoiner = + new StringJoiner(", ", PythonPartition.class.getSimpleName() + "[", "]"); + if (partition != null) { + stringJoiner.add("partition=binary partition"); + } else { + stringJoiner + .add("moduleName='" + moduleName + "'") + .add("functionName='" + functionName + "'"); + } + return stringJoiner.toString(); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonDataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonDataStream.java new file mode 100644 index 00000000..90f018ec --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonDataStream.java @@ -0,0 +1,202 @@ +package io.ray.streaming.python.stream; + +import io.ray.streaming.api.Language; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.api.stream.DataStream; +import io.ray.streaming.api.stream.Stream; +import io.ray.streaming.python.PythonFunction; +import io.ray.streaming.python.PythonFunction.FunctionInterface; +import io.ray.streaming.python.PythonOperator; +import 
io.ray.streaming.python.PythonPartition; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** Represents a stream of data whose transformations will be executed in python. */ +public class PythonDataStream extends Stream implements PythonStream { + + protected PythonDataStream(StreamingContext streamingContext, PythonOperator pythonOperator) { + super(streamingContext, pythonOperator); + } + + protected PythonDataStream( + StreamingContext streamingContext, + PythonOperator pythonOperator, + Partition partition) { + super(streamingContext, pythonOperator, partition); + } + + public PythonDataStream(PythonDataStream input, PythonOperator pythonOperator) { + super(input, pythonOperator); + } + + public PythonDataStream( + PythonDataStream input, PythonOperator pythonOperator, Partition partition) { + super(input, pythonOperator, partition); + } + + /** + * Create a python stream that reference passed java stream. Changes in new stream will be + * reflected in referenced stream and vice versa + */ + public PythonDataStream(DataStream referencedStream) { + super(referencedStream); + } + + public PythonDataStream map(String moduleName, String funcName) { + return map(new PythonFunction(moduleName, funcName)); + } + + /** + * Apply a map function to this stream. + * + * @param func The python MapFunction. + * @return A new PythonDataStream. + */ + public PythonDataStream map(PythonFunction func) { + func.setFunctionInterface(FunctionInterface.MAP_FUNCTION); + return new PythonDataStream(this, new PythonOperator(func)); + } + + public PythonDataStream flatMap(String moduleName, String funcName) { + return flatMap(new PythonFunction(moduleName, funcName)); + } + + /** + * Apply a flat-map function to this stream. + * + * @param func The python FlapMapFunction. 
+ * @return A new PythonDataStream + */ + public PythonDataStream flatMap(PythonFunction func) { + func.setFunctionInterface(FunctionInterface.FLAT_MAP_FUNCTION); + return new PythonDataStream(this, new PythonOperator(func)); + } + + public PythonDataStream filter(String moduleName, String funcName) { + return filter(new PythonFunction(moduleName, funcName)); + } + + /** + * Apply a filter function to this stream. + * + * @param func The python FilterFunction. + * @return A new PythonDataStream that contains only the elements satisfying the given filter + * predicate. + */ + public PythonDataStream filter(PythonFunction func) { + func.setFunctionInterface(FunctionInterface.FILTER_FUNCTION); + return new PythonDataStream(this, new PythonOperator(func)); + } + + /** + * Apply union transformations to this stream by merging {@link PythonDataStream} outputs of the + * same type with each other. + * + * @param stream The DataStream to union output with. + * @param others The other DataStreams to union output with. + * @return A new UnionStream. + */ + public final PythonDataStream union(PythonDataStream stream, PythonDataStream... others) { + List streams = new ArrayList<>(); + streams.add(stream); + streams.addAll(Arrays.asList(others)); + return union(streams); + } + + /** + * Apply union transformations to this stream by merging {@link PythonDataStream} outputs of the + * same type with each other. + * + * @param streams The DataStreams to union output with. + * @return A new UnionStream. + */ + public final PythonDataStream union(List streams) { + if (this instanceof PythonUnionStream) { + PythonUnionStream unionStream = (PythonUnionStream) this; + streams.forEach(unionStream::addStream); + return unionStream; + } else { + return new PythonUnionStream(this, streams); + } + } + + public PythonStreamSink sink(String moduleName, String funcName) { + return sink(new PythonFunction(moduleName, funcName)); + } + + /** + * Apply a sink function and get a StreamSink. 
+ * + * @param func The python SinkFunction. + * @return A new StreamSink. + */ + public PythonStreamSink sink(PythonFunction func) { + func.setFunctionInterface(FunctionInterface.SINK_FUNCTION); + return new PythonStreamSink(this, new PythonOperator(func)); + } + + public PythonKeyDataStream keyBy(String moduleName, String funcName) { + return keyBy(new PythonFunction(moduleName, funcName)); + } + + /** + * Apply a key-by function to this stream. + * + * @param func the python keyFunction. + * @return A new KeyDataStream. + */ + public PythonKeyDataStream keyBy(PythonFunction func) { + checkPartitionCall(); + func.setFunctionInterface(FunctionInterface.KEY_FUNCTION); + return new PythonKeyDataStream(this, new PythonOperator(func)); + } + + /** + * Apply broadcast to this stream. + * + * @return This stream. + */ + public PythonDataStream broadcast() { + checkPartitionCall(); + return setPartition(PythonPartition.BroadcastPartition); + } + + /** + * Apply a partition to this stream. + * + * @param partition The partitioning strategy. + * @return This stream. + */ + public PythonDataStream partitionBy(PythonPartition partition) { + checkPartitionCall(); + return setPartition(partition); + } + + /** + * If parent stream is a python stream, we can't call partition related methods in the java + * stream. + */ + private void checkPartitionCall() { + if (getInputStream() != null && getInputStream().getLanguage() == Language.JAVA) { + throw new RuntimeException( + "Partition related methods can't be called on a " + + "python stream if parent stream is a java stream."); + } + } + + /** + * Convert this stream as a java stream. The converted stream and this stream are the same logical + * stream, which has same stream id. Changes in converted stream will be reflected in this stream + * and vice versa. 
+ */ + public DataStream asJavaStream() { + return new DataStream<>(this); + } + + @Override + public Language getLanguage() { + return Language.PYTHON; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonKeyDataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonKeyDataStream.java new file mode 100644 index 00000000..078f84ac --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonKeyDataStream.java @@ -0,0 +1,52 @@ +package io.ray.streaming.python.stream; + +import io.ray.streaming.api.stream.DataStream; +import io.ray.streaming.api.stream.KeyDataStream; +import io.ray.streaming.operator.ChainStrategy; +import io.ray.streaming.python.PythonFunction; +import io.ray.streaming.python.PythonFunction.FunctionInterface; +import io.ray.streaming.python.PythonOperator; +import io.ray.streaming.python.PythonPartition; + +/** Represents a python DataStream returned by a key-by operation. */ +@SuppressWarnings("unchecked") +public class PythonKeyDataStream extends PythonDataStream implements PythonStream { + + public PythonKeyDataStream(PythonDataStream input, PythonOperator pythonOperator) { + super(input, pythonOperator, PythonPartition.KeyPartition); + } + + /** + * Create a python stream that reference passed python stream. Changes in new stream will be + * reflected in referenced stream and vice versa + */ + public PythonKeyDataStream(DataStream referencedStream) { + super(referencedStream); + } + + public PythonDataStream reduce(String moduleName, String funcName) { + return reduce(new PythonFunction(moduleName, funcName)); + } + + /** + * Apply a reduce function to this stream. + * + * @param func The reduce function. + * @return A new DataStream. 
+ */ + public PythonDataStream reduce(PythonFunction func) { + func.setFunctionInterface(FunctionInterface.REDUCE_FUNCTION); + PythonDataStream stream = new PythonDataStream(this, new PythonOperator(func)); + stream.withChainStrategy(ChainStrategy.HEAD); + return stream; + } + + /** + * Convert this stream as a java stream. The converted stream and this stream are the same logical + * stream, which has same stream id. Changes in converted stream will be reflected in this stream + * and vice versa. + */ + public KeyDataStream asJavaStream() { + return new KeyDataStream(this); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStream.java new file mode 100644 index 00000000..c89d23ca --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStream.java @@ -0,0 +1,4 @@ +package io.ray.streaming.python.stream; + +/** A marker interface used to identify all python streams. */ +public interface PythonStream {} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSink.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSink.java new file mode 100644 index 00000000..6f30d50a --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSink.java @@ -0,0 +1,19 @@ +package io.ray.streaming.python.stream; + +import io.ray.streaming.api.Language; +import io.ray.streaming.api.stream.StreamSink; +import io.ray.streaming.python.PythonOperator; + +/** Represents a sink of the PythonStream. 
*/ +public class PythonStreamSink extends StreamSink implements PythonStream { + + public PythonStreamSink(PythonDataStream input, PythonOperator sinkOperator) { + super(input, sinkOperator); + getStreamingContext().addSink(this); + } + + @Override + public Language getLanguage() { + return Language.PYTHON; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSource.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSource.java new file mode 100644 index 00000000..35af0012 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSource.java @@ -0,0 +1,23 @@ +package io.ray.streaming.python.stream; + +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.stream.StreamSource; +import io.ray.streaming.operator.ChainStrategy; +import io.ray.streaming.python.PythonFunction; +import io.ray.streaming.python.PythonFunction.FunctionInterface; +import io.ray.streaming.python.PythonOperator; + +/** Represents a source of the PythonStream. 
*/ +public class PythonStreamSource extends PythonDataStream implements StreamSource { + + private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) { + super(streamingContext, new PythonOperator(sourceFunction)); + withChainStrategy(ChainStrategy.HEAD); + } + + public static PythonStreamSource from( + StreamingContext streamingContext, PythonFunction sourceFunction) { + sourceFunction.setFunctionInterface(FunctionInterface.SOURCE_FUNCTION); + return new PythonStreamSource(streamingContext, sourceFunction); + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonUnionStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonUnionStream.java new file mode 100644 index 00000000..45c16ab6 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonUnionStream.java @@ -0,0 +1,36 @@ +package io.ray.streaming.python.stream; + +import io.ray.streaming.python.PythonOperator; +import java.util.ArrayList; +import java.util.List; + +/** + * Represents a union DataStream. + * + *

This stream does not create a physical operation, it only affects how upstream data are + * connected to downstream data. + */ +public class PythonUnionStream extends PythonDataStream { + + private List unionStreams; + + public PythonUnionStream(PythonDataStream input, List others) { + // Union stream does not create a physical operation, so we don't have to set partition + // function for it. + super(input, new PythonOperator("ray.streaming.operator", "UnionOperator")); + this.unionStreams = new ArrayList<>(); + others.forEach(this::addStream); + } + + void addStream(PythonDataStream stream) { + if (stream instanceof PythonUnionStream) { + this.unionStreams.addAll(((PythonUnionStream) stream).getUnionStreams()); + } else { + this.unionStreams.add(stream); + } + } + + public List getUnionStreams() { + return unionStreams; + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/util/Config.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/util/Config.java new file mode 100644 index 00000000..6238384e --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/util/Config.java @@ -0,0 +1,33 @@ +package io.ray.streaming.util; + +public class Config { + + public static final String STREAMING_JOB_NAME = "streaming.job.name"; + public static final String STREAMING_OP_NAME = "streaming.op_name"; + public static final String STREAMING_WORKER_NAME = "streaming.worker_name"; + + // channel + public static final String CHANNEL_TYPE = "channel_type"; + public static final String MEMORY_CHANNEL = "memory_channel"; + public static final String NATIVE_CHANNEL = "native_channel"; + public static final String CHANNEL_SIZE = "channel_size"; + public static final String CHANNEL_SIZE_DEFAULT = String.valueOf((long) Math.pow(10, 8)); + public static final String IS_RECREATE = "streaming.is_recreate"; + // return from DataReader.getBundle if only empty message read in this interval. 
+ public static final String TIMER_INTERVAL_MS = "timer_interval_ms"; + public static final String READ_TIMEOUT_MS = "read_timeout_ms"; + public static final String DEFAULT_READ_TIMEOUT_MS = "10"; + + public static final String STREAMING_RING_BUFFER_CAPACITY = "streaming.ring_buffer_capacity"; + // write an empty message if there is no data to be written in this + // interval. + public static final String STREAMING_EMPTY_MESSAGE_INTERVAL = "streaming.empty_message_interval"; + + // operator type + public static final String OPERATOR_TYPE = "operator_type"; + + // flow control + public static final String FLOW_CONTROL_TYPE = "streaming.flow_control_type"; + public static final String WRITER_CONSUMED_STEP = "streaming.writer.consumed_step"; + public static final String READER_CONSUMED_STEP = "streaming.reader.consumed_step"; +} diff --git a/streaming/java/streaming-api/src/main/resources/log4j.properties b/streaming/java/streaming-api/src/main/resources/log4j.properties new file mode 100644 index 00000000..30d876ae --- /dev/null +++ b/streaming/java/streaming-api/src/main/resources/log4j.properties @@ -0,0 +1,6 @@ +log4j.rootLogger=INFO, stdout +# Direct log messages to stdout +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.Target=System.out +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n diff --git a/streaming/java/streaming-api/src/main/resources/ray.conf b/streaming/java/streaming-api/src/main/resources/ray.conf new file mode 100644 index 00000000..7dbf3674 --- /dev/null +++ b/streaming/java/streaming-api/src/main/resources/ray.conf @@ -0,0 +1,5 @@ +ray { + run-mode = SINGLE_PROCESS + resources = "CPU:4" + address = "" +} diff --git a/streaming/java/streaming-api/src/test/java/io/ray/streaming/api/stream/StreamTest.java b/streaming/java/streaming-api/src/test/java/io/ray/streaming/api/stream/StreamTest.java new file mode 100644 
index 00000000..34a9e44d --- /dev/null +++ b/streaming/java/streaming-api/src/test/java/io/ray/streaming/api/stream/StreamTest.java @@ -0,0 +1,40 @@ +package io.ray.streaming.api.stream; + +import static org.testng.Assert.assertEquals; + +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.operator.impl.MapOperator; +import io.ray.streaming.python.stream.PythonDataStream; +import io.ray.streaming.python.stream.PythonKeyDataStream; +import org.testng.annotations.Test; + +@SuppressWarnings("unchecked") +public class StreamTest { + + @Test + public void testReferencedDataStream() { + DataStream dataStream = + new DataStream(StreamingContext.buildContext(), new MapOperator(value -> null)); + PythonDataStream pythonDataStream = dataStream.asPythonStream(); + DataStream javaStream = pythonDataStream.asJavaStream(); + assertEquals(dataStream.getId(), pythonDataStream.getId()); + assertEquals(dataStream.getId(), javaStream.getId()); + javaStream.setParallelism(10); + assertEquals(dataStream.getParallelism(), pythonDataStream.getParallelism()); + assertEquals(dataStream.getParallelism(), javaStream.getParallelism()); + } + + @Test + public void testReferencedKeyDataStream() { + DataStream dataStream = + new DataStream(StreamingContext.buildContext(), new MapOperator(value -> null)); + KeyDataStream keyDataStream = dataStream.keyBy(value -> null); + PythonKeyDataStream pythonKeyDataStream = keyDataStream.asPythonStream(); + KeyDataStream javaKeyDataStream = pythonKeyDataStream.asJavaStream(); + assertEquals(keyDataStream.getId(), pythonKeyDataStream.getId()); + assertEquals(keyDataStream.getId(), javaKeyDataStream.getId()); + javaKeyDataStream.setParallelism(10); + assertEquals(keyDataStream.getParallelism(), pythonKeyDataStream.getParallelism()); + assertEquals(keyDataStream.getParallelism(), javaKeyDataStream.getParallelism()); + } +} diff --git a/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphBuilderTest.java 
b/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphBuilderTest.java new file mode 100644 index 00000000..d0eec654 --- /dev/null +++ b/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphBuilderTest.java @@ -0,0 +1,93 @@ +package io.ray.streaming.jobgraph; + +import com.google.common.collect.Lists; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.partition.impl.ForwardPartition; +import io.ray.streaming.api.partition.impl.KeyPartition; +import io.ray.streaming.api.stream.DataStream; +import io.ray.streaming.api.stream.DataStreamSource; +import io.ray.streaming.api.stream.StreamSink; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class JobGraphBuilderTest { + + private static final Logger LOG = LoggerFactory.getLogger(JobGraphBuilderTest.class); + + @Test + public void testDataSync() { + JobGraph jobGraph = buildDataSyncJobGraph(); + List jobVertexList = jobGraph.getJobVertices(); + List jobEdgeList = jobGraph.getJobEdges(); + + Assert.assertEquals(jobVertexList.size(), 2); + Assert.assertEquals(jobEdgeList.size(), 1); + + JobEdge jobEdge = jobEdgeList.get(0); + Assert.assertEquals(jobEdge.getPartition().getClass(), ForwardPartition.class); + + JobVertex sinkVertex = jobVertexList.get(1); + JobVertex sourceVertex = jobVertexList.get(0); + Assert.assertEquals(sinkVertex.getVertexType(), VertexType.SINK); + Assert.assertEquals(sourceVertex.getVertexType(), VertexType.SOURCE); + } + + public JobGraph buildDataSyncJobGraph() { + StreamingContext streamingContext = StreamingContext.buildContext(); + DataStream dataStream = + DataStreamSource.fromCollection(streamingContext, Lists.newArrayList("a", "b", "c")); + StreamSink streamSink = dataStream.sink(x -> LOG.info(x)); + JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink)); + + JobGraph jobGraph 
= jobGraphBuilder.build(); + return jobGraph; + } + + @Test + public void testKeyByJobGraph() { + JobGraph jobGraph = buildKeyByJobGraph(); + List jobVertexList = jobGraph.getJobVertices(); + List jobEdgeList = jobGraph.getJobEdges(); + + Assert.assertEquals(jobVertexList.size(), 3); + Assert.assertEquals(jobEdgeList.size(), 2); + + JobVertex source = jobVertexList.get(0); + JobVertex map = jobVertexList.get(1); + JobVertex sink = jobVertexList.get(2); + + Assert.assertEquals(source.getVertexType(), VertexType.SOURCE); + Assert.assertEquals(map.getVertexType(), VertexType.TRANSFORMATION); + Assert.assertEquals(sink.getVertexType(), VertexType.SINK); + + JobEdge keyBy2Sink = jobEdgeList.get(0); + JobEdge source2KeyBy = jobEdgeList.get(1); + + Assert.assertEquals(keyBy2Sink.getPartition().getClass(), KeyPartition.class); + Assert.assertEquals(source2KeyBy.getPartition().getClass(), ForwardPartition.class); + } + + public JobGraph buildKeyByJobGraph() { + StreamingContext streamingContext = StreamingContext.buildContext(); + DataStream dataStream = + DataStreamSource.fromCollection(streamingContext, Lists.newArrayList("1", "2", "3", "4")); + StreamSink streamSink = dataStream.keyBy(x -> x).sink(x -> LOG.info(x)); + JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink)); + + JobGraph jobGraph = jobGraphBuilder.build(); + return jobGraph; + } + + @Test + public void testJobGraphViz() { + JobGraph jobGraph = buildKeyByJobGraph(); + jobGraph.generateDigraph(); + String diGraph = jobGraph.getDigraph(); + LOG.info(diGraph); + Assert.assertTrue(diGraph.contains("\"1-SourceOperatorImpl\" -> \"2-KeyByOperator\"")); + Assert.assertTrue(diGraph.contains("\"2-KeyByOperator\" -> \"3-SinkOperator\"")); + } +} diff --git a/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphOptimizerTest.java b/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphOptimizerTest.java new file mode 100644 index 
00000000..bc0854f9 --- /dev/null +++ b/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphOptimizerTest.java @@ -0,0 +1,72 @@ +package io.ray.streaming.jobgraph; + +import static org.testng.Assert.assertEquals; + +import com.google.common.collect.Lists; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.stream.DataStream; +import io.ray.streaming.api.stream.DataStreamSource; +import io.ray.streaming.python.PythonFunction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.annotations.Test; + +public class JobGraphOptimizerTest { + + private static final Logger LOG = LoggerFactory.getLogger(JobGraphOptimizerTest.class); + + @Test + public void testOptimize() { + StreamingContext context = StreamingContext.buildContext(); + DataStream source1 = + DataStreamSource.fromCollection(context, Lists.newArrayList(1, 2, 3)); + DataStream source2 = + DataStreamSource.fromCollection(context, Lists.newArrayList("1", "2", "3")); + DataStream source3 = + DataStreamSource.fromCollection(context, Lists.newArrayList("2", "3", "4")); + source1 + .filter(x -> x > 1) + .map(String::valueOf) + .union(source2) + .join(source3) + .sink(x -> System.out.println("Sink " + x)); + JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build(); + LOG.info("Digraph {}", jobGraph.generateDigraph()); + assertEquals(jobGraph.getJobVertices().size(), 8); + + JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph); + JobGraph optimizedJobGraph = graphOptimizer.optimize(); + optimizedJobGraph.printJobGraph(); + LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph()); + assertEquals(optimizedJobGraph.getJobVertices().size(), 5); + } + + @Test + public void testOptimizeHybridStream() { + StreamingContext context = StreamingContext.buildContext(); + DataStream source1 = + DataStreamSource.fromCollection(context, Lists.newArrayList(1, 2, 3)); + DataStream source2 = + 
DataStreamSource.fromCollection(context, Lists.newArrayList("1", "2", "3")); + source1 + .asPythonStream() + .map(pyFunc(1)) + .filter(pyFunc(2)) + .union(source2.asPythonStream().filter(pyFunc(3)).map(pyFunc(4))) + .asJavaStream() + .sink(x -> System.out.println("Sink " + x)); + JobGraph jobGraph = new JobGraphBuilder(context.getStreamSinks()).build(); + LOG.info("Digraph {}", jobGraph.generateDigraph()); + assertEquals(jobGraph.getJobVertices().size(), 8); + + JobGraphOptimizer graphOptimizer = new JobGraphOptimizer(jobGraph); + JobGraph optimizedJobGraph = graphOptimizer.optimize(); + optimizedJobGraph.printJobGraph(); + LOG.info("Optimized graph {}", optimizedJobGraph.generateDigraph()); + assertEquals(optimizedJobGraph.getJobVertices().size(), 6); + } + + private PythonFunction pyFunc(int number) { + return new PythonFunction("module", "func" + number); + } +} diff --git a/streaming/java/streaming-api/src/test/resources/log4j.properties b/streaming/java/streaming-api/src/test/resources/log4j.properties new file mode 100644 index 00000000..30d876ae --- /dev/null +++ b/streaming/java/streaming-api/src/test/resources/log4j.properties @@ -0,0 +1,6 @@ +log4j.rootLogger=INFO, stdout +# Direct log messages to stdout +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.Target=System.out +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n diff --git a/streaming/java/streaming-api/src/test/resources/ray.conf b/streaming/java/streaming-api/src/test/resources/ray.conf new file mode 100644 index 00000000..fdc897fa --- /dev/null +++ b/streaming/java/streaming-api/src/test/resources/ray.conf @@ -0,0 +1,3 @@ +ray { + run-mode = SINGLE_PROCESS +} diff --git a/streaming/java/streaming-runtime/pom_template.xml b/streaming/java/streaming-runtime/pom_template.xml new file mode 100644 index 00000000..e7c9a830 --- /dev/null +++ 
b/streaming/java/streaming-runtime/pom_template.xml @@ -0,0 +1,95 @@ + + {auto_gen_header} + + + ray-streaming + io.ray + 2.0.0-SNAPSHOT + + 4.0.0 + + streaming-runtime + ray streaming runtime + ray streaming runtime + jar + + ${basedir}/../../build/java + + + + release + + + release + true + + false + + + ${basedir} + + + + + + + io.ray + ray-api + ${project.version} + + + io.ray + ray-runtime + ${project.version} + + + io.ray + streaming-api + ${project.version} + + {generated_bzl_deps} + + + + + + src/main/resources + + + native_dependencies + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-dependencies-to-build + package + + copy-dependencies + + + ${basedir}/../../build/java + false + false + true + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.3.1 + + ${output.directory} + + + + + diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/client/JobClientImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/client/JobClientImpl.java new file mode 100644 index 00000000..68ca33ec --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/client/JobClientImpl.java @@ -0,0 +1,55 @@ +package io.ray.streaming.runtime.client; + +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.streaming.client.JobClient; +import io.ray.streaming.jobgraph.JobGraph; +import io.ray.streaming.runtime.config.global.CommonConfig; +import io.ray.streaming.runtime.master.JobMaster; +import java.util.HashMap; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Job client: to submit job from api to runtime. 
*/ +public class JobClientImpl implements JobClient { + + public static final Logger LOG = LoggerFactory.getLogger(JobClientImpl.class); + + private ActorHandle jobMasterActor; + + @Override + public void submit(JobGraph jobGraph, Map jobConfig) { + LOG.info( + "Submitting job [{}] with job graph [{}] and job config [{}].", + jobGraph.getJobName(), + jobGraph, + jobConfig); + Map resources = new HashMap<>(); + + // set job name and id at start + jobConfig.put(CommonConfig.JOB_ID, Ray.getRuntimeContext().getCurrentJobId().toString()); + jobConfig.put(CommonConfig.JOB_NAME, jobGraph.getJobName()); + + jobGraph.getJobConfig().putAll(jobConfig); + + // create job master actor + this.jobMasterActor = + Ray.actor(JobMaster::new, jobConfig).setResources(resources).setMaxRestarts(-1).remote(); + + try { + ObjectRef submitResult = + jobMasterActor.task(JobMaster::submitJob, jobMasterActor, jobGraph).remote(); + + if (submitResult.get()) { + LOG.info("Finish submitting job: {}.", jobGraph.getJobName()); + } else { + throw new RuntimeException("submitting job failed"); + } + } catch (Exception e) { + LOG.error("Failed to submit job: {}.", jobGraph.getJobName(), e); + throw new RuntimeException("submitting job failed", e); + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/Config.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/Config.java new file mode 100644 index 00000000..d4bc3161 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/Config.java @@ -0,0 +1,6 @@ +package io.ray.streaming.runtime.config; + +import org.aeonbits.owner.Accessible; + +/** Basic config interface. 
*/ +public interface Config extends org.aeonbits.owner.Config, Accessible {} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingConfig.java new file mode 100644 index 00000000..e300ef7b --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingConfig.java @@ -0,0 +1,22 @@ +package io.ray.streaming.runtime.config; + +import java.io.Serializable; +import java.util.Map; + +/** Streaming config including general, master and worker part. */ +public class StreamingConfig implements Serializable { + + public StreamingMasterConfig masterConfig; + public StreamingWorkerConfig workerConfigTemplate; + + public StreamingConfig(final Map conf) { + masterConfig = new StreamingMasterConfig(conf); + workerConfigTemplate = new StreamingWorkerConfig(conf); + } + + public Map getMap() { + Map wholeConfigMap = masterConfig.configMap; + wholeConfigMap.putAll(workerConfigTemplate.configMap); + return wholeConfigMap; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingGlobalConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingGlobalConfig.java new file mode 100644 index 00000000..8497ac64 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingGlobalConfig.java @@ -0,0 +1,90 @@ +package io.ray.streaming.runtime.config; + +import com.google.common.base.Preconditions; +import io.ray.streaming.runtime.config.global.CheckpointConfig; +import io.ray.streaming.runtime.config.global.CommonConfig; +import io.ray.streaming.runtime.config.global.ContextBackendConfig; +import io.ray.streaming.runtime.config.global.TransferConfig; +import java.io.Serializable; +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.Map; 
+import org.aeonbits.owner.Config.DefaultValue; +import org.aeonbits.owner.Config.Key; +import org.aeonbits.owner.ConfigFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Streaming general config. May used by both JobMaster and JobWorker. */ +public class StreamingGlobalConfig implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(StreamingGlobalConfig.class); + public final CommonConfig commonConfig; + public final TransferConfig transferConfig; + public final Map configMap; + public CheckpointConfig checkpointConfig; + public ContextBackendConfig contextBackendConfig; + + public StreamingGlobalConfig(final Map conf) { + configMap = new HashMap<>(conf); + + commonConfig = ConfigFactory.create(CommonConfig.class, conf); + transferConfig = ConfigFactory.create(TransferConfig.class, conf); + checkpointConfig = ConfigFactory.create(CheckpointConfig.class, conf); + contextBackendConfig = ConfigFactory.create(ContextBackendConfig.class, conf); + globalConfig2Map(); + } + + @Override + public String toString() { + return configMap.toString(); + } + + private void globalConfig2Map() { + try { + configMap.putAll(config2Map(this.commonConfig)); + configMap.putAll(config2Map(this.transferConfig)); + } catch (Exception e) { + LOG.error("Couldn't convert global config to a map.", e); + } + } + + protected Map config2Map(org.aeonbits.owner.Config config) + throws ClassNotFoundException { + Map result = new HashMap<>(); + + Class proxyClazz = Class.forName(config.getClass().getName()); + Class[] proxyInterfaces = proxyClazz.getInterfaces(); + + Class configInterface = null; + for (Class proxyInterface : proxyInterfaces) { + if (Config.class.isAssignableFrom(proxyInterface)) { + configInterface = proxyInterface; + break; + } + } + Preconditions.checkArgument(configInterface != null, "Can not get config interface."); + Method[] methods = configInterface.getMethods(); + + for (Method method : methods) { + Key 
ownerKeyAnnotation = method.getAnnotation(Key.class); + String ownerKeyAnnotationValue; + if (ownerKeyAnnotation != null) { + ownerKeyAnnotationValue = ownerKeyAnnotation.value(); + Object value; + try { + value = method.invoke(config); + } catch (Exception e) { + LOG.warn( + "Can not get value by method invoking for config key: {}. " + + "So use default value instead.", + ownerKeyAnnotationValue); + String defaultValue = method.getAnnotation(DefaultValue.class).value(); + value = defaultValue; + } + result.put(ownerKeyAnnotationValue, value + ""); + } + } + return result; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingMasterConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingMasterConfig.java new file mode 100644 index 00000000..cd66c973 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingMasterConfig.java @@ -0,0 +1,23 @@ +package io.ray.streaming.runtime.config; + +import io.ray.streaming.runtime.config.master.ResourceConfig; +import io.ray.streaming.runtime.config.master.SchedulerConfig; +import java.util.Map; +import org.aeonbits.owner.ConfigFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Streaming job master config. 
*/ +public class StreamingMasterConfig extends StreamingGlobalConfig { + + private static final Logger LOG = LoggerFactory.getLogger(StreamingMasterConfig.class); + + public ResourceConfig resourceConfig; + public SchedulerConfig schedulerConfig; + + public StreamingMasterConfig(final Map conf) { + super(conf); + this.resourceConfig = ConfigFactory.create(ResourceConfig.class, conf); + this.schedulerConfig = ConfigFactory.create(SchedulerConfig.class, conf); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingWorkerConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingWorkerConfig.java new file mode 100644 index 00000000..f644cd77 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingWorkerConfig.java @@ -0,0 +1,34 @@ +package io.ray.streaming.runtime.config; + +import io.ray.streaming.runtime.config.worker.WorkerInternalConfig; +import java.util.HashMap; +import java.util.Map; +import org.aeonbits.owner.ConfigFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Streaming job worker specified config. 
*/ +public class StreamingWorkerConfig extends StreamingGlobalConfig { + + private static final Logger LOG = LoggerFactory.getLogger(StreamingWorkerConfig.class); + + public WorkerInternalConfig workerInternalConfig; + + public StreamingWorkerConfig(final Map conf) { + super(conf); + workerInternalConfig = ConfigFactory.create(WorkerInternalConfig.class, conf); + + configMap.putAll(workerConfig2Map()); + } + + public Map workerConfig2Map() { + Map result = new HashMap<>(); + try { + result.putAll(config2Map(this.workerInternalConfig)); + } catch (Exception e) { + LOG.error("Worker config to map occur error.", e); + return null; + } + return result; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CheckpointConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CheckpointConfig.java new file mode 100644 index 00000000..0d57feba --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CheckpointConfig.java @@ -0,0 +1,53 @@ +package io.ray.streaming.runtime.config.global; + +import io.ray.streaming.runtime.config.Config; +import org.aeonbits.owner.Mutable; + +/** Configurations for checkpointing. */ +public interface CheckpointConfig extends Config, Mutable { + + String CP_INTERVAL_SECS = "streaming.checkpoint.interval.secs"; + String CP_TIMEOUT_SECS = "streaming.checkpoint.timeout.secs"; + + String CP_PREFIX_KEY_MASTER = "streaming.checkpoint.prefix-key.job-master.context"; + String CP_PREFIX_KEY_WORKER = "streaming.checkpoint.prefix-key.job-worker.context"; + String CP_PREFIX_KEY_OPERATOR = "streaming.checkpoint.prefix-key.job-worker.operator"; + + /** + * Checkpoint time interval. JobMaster won't trigger 2 checkpoint in less than this time interval. + */ + @DefaultValue(value = "5") + @Key(value = CP_INTERVAL_SECS) + int cpIntervalSecs(); + + /** + * How long should JobMaster wait for checkpoint to finish. 
When this timeout is reached and + * JobMaster hasn't received all commits from workers, JobMaster will consider this checkpoint as + * failed and trigger another checkpoint. + */ + @DefaultValue(value = "120") + @Key(value = CP_TIMEOUT_SECS) + int cpTimeoutSecs(); + + /** + * This is used for saving JobMaster's context to storage, user usually don't need to change this. + */ + @DefaultValue(value = "job_master_runtime_context_") + @Key(value = CP_PREFIX_KEY_MASTER) + String jobMasterContextCpPrefixKey(); + + /** + * This is used for saving JobWorker's context to storage, user usually don't need to change this. + */ + @DefaultValue(value = "job_worker_context_") + @Key(value = CP_PREFIX_KEY_WORKER) + String jobWorkerContextCpPrefixKey(); + + /** + * This is used for saving user operator(in StreamTask)'s context to storage, user usually don't + * need to change this. + */ + @DefaultValue(value = "job_worker_op_") + @Key(value = CP_PREFIX_KEY_OPERATOR) + String jobWorkerOpCpPrefixKey(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CommonConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CommonConfig.java new file mode 100644 index 00000000..2ec3b6df --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CommonConfig.java @@ -0,0 +1,28 @@ +package io.ray.streaming.runtime.config.global; + +import io.ray.streaming.runtime.config.Config; + +/** Job common config. */ +public interface CommonConfig extends Config { + + String JOB_ID = "streaming.job.id"; + String JOB_NAME = "streaming.job.name"; + + /** + * Ray streaming job id. Non-custom. + * + * @return Job id with string type. + */ + @DefaultValue(value = "default-job-id") + @Key(value = JOB_ID) + String jobId(); + + /** + * Ray streaming job name. Non-custom. + * + * @return Job name with string type. 
+ */ + @DefaultValue(value = "default-job-name") + @Key(value = JOB_NAME) + String jobName(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/ContextBackendConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/ContextBackendConfig.java new file mode 100644 index 00000000..11d1d337 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/ContextBackendConfig.java @@ -0,0 +1,17 @@ +package io.ray.streaming.runtime.config.global; + +import org.aeonbits.owner.Config; + +public interface ContextBackendConfig extends Config { + + String STATE_BACKEND_TYPE = "streaming.context-backend.type"; + String FILE_STATE_ROOT_PATH = "streaming.context-backend.file-state.root"; + + @Config.DefaultValue(value = "memory") + @Key(value = STATE_BACKEND_TYPE) + String stateBackendType(); + + @Config.DefaultValue(value = "/tmp/ray_streaming_state") + @Key(value = FILE_STATE_ROOT_PATH) + String fileStateRootPath(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/TransferConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/TransferConfig.java new file mode 100644 index 00000000..eb497809 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/TransferConfig.java @@ -0,0 +1,47 @@ +package io.ray.streaming.runtime.config.global; + +import io.ray.streaming.runtime.config.Config; +import io.ray.streaming.runtime.config.types.TransferChannelType; + +/** Job data transfer config. */ +public interface TransferConfig extends Config { + + /** Data transfer channel type, support memory queue and native queue. */ + @DefaultValue(value = "NATIVE_CHANNEL") + @Key(value = io.ray.streaming.util.Config.CHANNEL_TYPE) + TransferChannelType channelType(); + + /** Queue size. 
*/ + @DefaultValue(value = "100000000") + @Key(value = io.ray.streaming.util.Config.CHANNEL_SIZE) + long channelSize(); + + /** Return from DataReader.getBundle if only empty message read in this interval. */ + @DefaultValue(value = "-1") + @Key(value = io.ray.streaming.util.Config.TIMER_INTERVAL_MS) + long readerTimerIntervalMs(); + + /** Ring capacity. */ + @DefaultValue(value = "-1") + @Key(value = io.ray.streaming.util.Config.STREAMING_RING_BUFFER_CAPACITY) + int ringBufferCapacity(); + + /** Write an empty message if there is no data to be written in this interval. */ + @DefaultValue(value = "-1") + @Key(value = io.ray.streaming.util.Config.STREAMING_EMPTY_MESSAGE_INTERVAL) + int emptyMsgInterval(); + + // Flow control + + @DefaultValue(value = "-1") + @Key(value = io.ray.streaming.util.Config.FLOW_CONTROL_TYPE) + int flowControlType(); + + @DefaultValue(value = "-1") + @Key(value = io.ray.streaming.util.Config.WRITER_CONSUMED_STEP) + int writerConsumedStep(); + + @DefaultValue(value = "-1") + @Key(value = io.ray.streaming.util.Config.READER_CONSUMED_STEP) + int readerConsumedStep(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/master/ResourceConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/master/ResourceConfig.java new file mode 100644 index 00000000..21a18e29 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/master/ResourceConfig.java @@ -0,0 +1,55 @@ +package io.ray.streaming.runtime.config.master; + +import io.ray.streaming.runtime.config.Config; + +/** Job resource management config. */ +public interface ResourceConfig extends Config { + + /** Number of actors per container. */ + String MAX_ACTOR_NUM_PER_CONTAINER = "streaming.container.per.max.actor"; + + /** The interval between detecting ray cluster nodes. 
*/ + String CONTAINER_RESOURCE_CHECk_INTERVAL_SECOND = "streaming.resource.check.interval.second"; + + /** CPU use by per task. */ + String TASK_RESOURCE_CPU = "streaming.task.resource.cpu"; + + /** Memory use by each task */ + String TASK_RESOURCE_MEM = "streaming.task.resource.mem"; + + /** Whether to enable CPU limit in resource control. */ + String TASK_RESOURCE_CPU_LIMIT_ENABLE = "streaming.task.resource.cpu.limitation.enable"; + + /** Whether to enable memory limit in resource control. */ + String TASK_RESOURCE_MEM_LIMIT_ENABLE = "streaming.task.resource.mem.limitation.enable"; + + /** Number of cpu per task. */ + @DefaultValue(value = "1.0") + @Key(value = TASK_RESOURCE_CPU) + double taskCpuResource(); + + /** Memory size used by each task. */ + @DefaultValue(value = "2.0") + @Key(value = TASK_RESOURCE_MEM) + double taskMemResource(); + + /** Whether to enable CPU limit in resource control. */ + @DefaultValue(value = "false") + @Key(value = TASK_RESOURCE_CPU_LIMIT_ENABLE) + boolean isTaskCpuResourceLimit(); + + /** Whether to enable memory limit in resource control. */ + @DefaultValue(value = "false") + @Key(value = TASK_RESOURCE_MEM_LIMIT_ENABLE) + boolean isTaskMemResourceLimit(); + + /** Number of actors per container. */ + @DefaultValue(value = "500") + @Key(MAX_ACTOR_NUM_PER_CONTAINER) + int actorNumPerContainer(); + + /** The interval between detecting ray cluster nodes. 
*/ + @DefaultValue(value = "1") + @Key(value = CONTAINER_RESOURCE_CHECk_INTERVAL_SECOND) + long resourceCheckIntervalSecond(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/master/SchedulerConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/master/SchedulerConfig.java new file mode 100644 index 00000000..79189431 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/master/SchedulerConfig.java @@ -0,0 +1,28 @@ +package io.ray.streaming.runtime.config.master; + +import io.ray.streaming.runtime.config.Config; + +/** Configuration for job scheduler. */ +public interface SchedulerConfig extends Config { + + String WORKER_INITIATION_WAIT_TIMEOUT_MS = "streaming.scheduler.worker.initiation.timeout.ms"; + String WORKER_STARTING_WAIT_TIMEOUT_MS = "streaming.scheduler.worker.starting.timeout.ms"; + + /** + * The timeout ms of worker initiation. Default is: 10000ms(10s). + * + * @return timeout ms + */ + @Key(WORKER_INITIATION_WAIT_TIMEOUT_MS) + @DefaultValue(value = "10000") + int workerInitiationWaitTimeoutMs(); + + /** + * The timeout ms of worker starting. Default is: 10000ms(10s). 
+ * + * @return timeout ms + */ + @Key(WORKER_STARTING_WAIT_TIMEOUT_MS) + @DefaultValue(value = "10000") + int workerStartingWaitTimeoutMs(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ContextBackendType.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ContextBackendType.java new file mode 100644 index 00000000..a9856c85 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ContextBackendType.java @@ -0,0 +1,18 @@ +package io.ray.streaming.runtime.config.types; + +public enum ContextBackendType { + + /** Memory type */ + MEMORY("memory", 0), + + /** Local File */ + LOCAL_FILE("local_file", 1); + + private String name; + private int index; + + ContextBackendType(String name, int index) { + this.name = name; + this.index = index; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ResourceAssignStrategyType.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ResourceAssignStrategyType.java new file mode 100644 index 00000000..05c2f657 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ResourceAssignStrategyType.java @@ -0,0 +1,19 @@ +package io.ray.streaming.runtime.config.types; + +public enum ResourceAssignStrategyType { + + /** Resource scheduling strategy based on FF(First Fit) algorithm and pipeline. 
*/ + PIPELINE_FIRST_STRATEGY("pipeline_first_strategy", 0); + + private String name; + private int index; + + ResourceAssignStrategyType(String name, int index) { + this.name = name; + this.index = index; + } + + public String getName() { + return name; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/TransferChannelType.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/TransferChannelType.java new file mode 100644 index 00000000..21436790 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/TransferChannelType.java @@ -0,0 +1,23 @@ +package io.ray.streaming.runtime.config.types; + +/** Data transfer channel type. */ +public enum TransferChannelType { + + /** Memory queue. */ + MEMORY_CHANNEL("memory_channel", 0), + + /** Native queue. */ + NATIVE_CHANNEL("native_channel", 1); + + private String value; + private int index; + + TransferChannelType(String value, int index) { + this.value = value; + this.index = index; + } + + public String getValue() { + return value; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/worker/WorkerInternalConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/worker/WorkerInternalConfig.java new file mode 100644 index 00000000..e2eb62cb --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/worker/WorkerInternalConfig.java @@ -0,0 +1,21 @@ +package io.ray.streaming.runtime.config.worker; + +import io.ray.streaming.runtime.config.Config; +import org.aeonbits.owner.Mutable; + +/** This worker config is used by JobMaster to define the internal configuration of JobWorker. 
*/ +public interface WorkerInternalConfig extends Config, Mutable { + + String WORKER_NAME_INTERNAL = io.ray.streaming.util.Config.STREAMING_WORKER_NAME; + String OP_NAME_INTERNAL = io.ray.streaming.util.Config.STREAMING_OP_NAME; + + /** The name of the worker inside the system. */ + @DefaultValue(value = "default-worker-name") + @Key(value = WORKER_NAME_INTERNAL) + String workerName(); + + /** Operator name corresponding to worker. */ + @DefaultValue(value = "default-worker-op-name") + @Key(value = OP_NAME_INTERNAL) + String workerOperatorName(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackend.java new file mode 100644 index 00000000..83b62696 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackend.java @@ -0,0 +1,41 @@ +package io.ray.streaming.runtime.context; + +import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.worker.JobWorker; + +/** + * This interface is used for storing context of {@link JobWorker} and {@link JobMaster}. The + * checkpoint returned by user function is also saved using this interface. 
+ */ +public interface ContextBackend { + + /** + * check if key exists in state + * + * @return true if exists + */ + boolean exists(final String key) throws Exception; + + /** + * get content by key + * + * @param key key + * @return the StateBackend + */ + byte[] get(final String key) throws Exception; + + /** + * put content by key + * + * @param key key + * @param value content + */ + void put(final String key, final byte[] value) throws Exception; + + /** + * remove content by key + * + * @param key key + */ + void remove(final String key) throws Exception; +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackendFactory.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackendFactory.java new file mode 100644 index 00000000..bb0af08a --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackendFactory.java @@ -0,0 +1,27 @@ +package io.ray.streaming.runtime.context; + +import io.ray.streaming.runtime.config.StreamingGlobalConfig; +import io.ray.streaming.runtime.config.types.ContextBackendType; +import io.ray.streaming.runtime.context.impl.AtomicFsBackend; +import io.ray.streaming.runtime.context.impl.MemoryContextBackend; + +public class ContextBackendFactory { + + public static ContextBackend getContextBackend(final StreamingGlobalConfig config) { + ContextBackend contextBackend; + ContextBackendType type = + ContextBackendType.valueOf(config.contextBackendConfig.stateBackendType().toUpperCase()); + + switch (type) { + case MEMORY: + contextBackend = new MemoryContextBackend(config.contextBackendConfig); + break; + case LOCAL_FILE: + contextBackend = new AtomicFsBackend(config.contextBackendConfig); + break; + default: + throw new RuntimeException("Unsupported context backend type."); + } + return contextBackend; + } +} diff --git 
a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/OperatorCheckpointInfo.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/OperatorCheckpointInfo.java new file mode 100644 index 00000000..644e2d33 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/OperatorCheckpointInfo.java @@ -0,0 +1,48 @@ +package io.ray.streaming.runtime.context; + +import com.google.common.base.MoreObjects; +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +/** This data structure contains state information of a task. */ +public class OperatorCheckpointInfo implements Serializable { + + /** key: channel ID, value: offset */ + public Map inputPoints; + + public Map outputPoints; + + /** a serializable checkpoint returned by processor */ + public Serializable processorCheckpoint; + + public long checkpointId; + + public OperatorCheckpointInfo() { + inputPoints = new HashMap<>(); + outputPoints = new HashMap<>(); + checkpointId = -1; + } + + public OperatorCheckpointInfo( + Map inputPoints, + Map outputPoints, + Serializable processorCheckpoint, + long checkpointId) { + this.inputPoints = inputPoints; + this.outputPoints = outputPoints; + this.checkpointId = checkpointId; + this.processorCheckpoint = processorCheckpoint; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("inputPoints", inputPoints) + .add("outputPoints", outputPoints) + .add("processorCheckpoint", processorCheckpoint) + .add("checkpointId", checkpointId) + .toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/AtomicFsBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/AtomicFsBackend.java new file mode 100644 index 00000000..3baf1578 --- /dev/null +++ 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/AtomicFsBackend.java @@ -0,0 +1,48 @@ +package io.ray.streaming.runtime.context.impl; + +import io.ray.streaming.runtime.config.global.ContextBackendConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Achieves an atomic `put` method. known issue: if you crashed while write a key at first time, + * this code will not work. + */ +public class AtomicFsBackend extends LocalFileContextBackend { + + private static final Logger LOG = LoggerFactory.getLogger(AtomicFsBackend.class); + private static final String TMP_FLAG = "_tmp"; + + public AtomicFsBackend(final ContextBackendConfig config) { + super(config); + } + + @Override + public byte[] get(String key) throws Exception { + String tmpKey = key + TMP_FLAG; + if (super.exists(tmpKey) && !super.exists(key)) { + return super.get(tmpKey); + } + return super.get(key); + } + + @Override + public void put(String key, byte[] value) throws Exception { + String tmpKey = key + TMP_FLAG; + if (super.exists(tmpKey) && !super.exists(key)) { + super.rename(tmpKey, key); + } + super.put(tmpKey, value); + super.remove(key); + super.rename(tmpKey, key); + } + + @Override + public void remove(String key) { + String tmpKey = key + TMP_FLAG; + if (super.exists(tmpKey)) { + super.remove(tmpKey); + } + super.remove(key); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/LocalFileContextBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/LocalFileContextBackend.java new file mode 100644 index 00000000..cd700590 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/LocalFileContextBackend.java @@ -0,0 +1,54 @@ +package io.ray.streaming.runtime.context.impl; + +import io.ray.streaming.runtime.config.global.ContextBackendConfig; +import 
io.ray.streaming.runtime.context.ContextBackend; +import java.io.File; +import org.apache.commons.io.FileUtils; + +/** + * This context backend uses local file system and doesn't supports failover in cluster. But it + * supports failover in single node. This is a pure file system backend which doesn't support atomic + * writing, please don't use this class, instead, use {@link AtomicFsBackend} which extends this + * class. + */ +public class LocalFileContextBackend implements ContextBackend { + + private final String rootPath; + + public LocalFileContextBackend(ContextBackendConfig config) { + rootPath = config.fileStateRootPath(); + } + + @Override + public boolean exists(String key) { + File file = new File(rootPath, key); + return file.exists(); + } + + @Override + public byte[] get(String key) throws Exception { + File file = new File(rootPath, key); + if (file.exists()) { + return FileUtils.readFileToByteArray(file); + } + return null; + } + + @Override + public void put(String key, byte[] value) throws Exception { + File file = new File(rootPath, key); + FileUtils.writeByteArrayToFile(file, value); + } + + @Override + public void remove(String key) { + File file = new File(rootPath, key); + FileUtils.deleteQuietly(file); + } + + protected void rename(String fromKey, String toKey) throws Exception { + File srcFile = new File(rootPath, fromKey); + File dstFile = new File(rootPath, toKey); + FileUtils.moveFile(srcFile, dstFile); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/MemoryContextBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/MemoryContextBackend.java new file mode 100644 index 00000000..eb9ea153 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/MemoryContextBackend.java @@ -0,0 +1,72 @@ +package io.ray.streaming.runtime.context.impl; + +import 
io.ray.streaming.runtime.config.global.ContextBackendConfig; +import io.ray.streaming.runtime.context.ContextBackend; +import java.util.HashMap; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This context backend uses memory and doesn't supports failover. Data will be lost after worker + * died. + */ +public class MemoryContextBackend implements ContextBackend { + + private static final Logger LOG = LoggerFactory.getLogger(MemoryContextBackend.class); + + private final Map kvStore = new HashMap<>(); + + public MemoryContextBackend(ContextBackendConfig config) { + if (LOG.isInfoEnabled()) { + LOG.info("Start init memory state backend, config is {}.", config); + LOG.info("Finish init memory state backend."); + } + } + + @Override + public boolean exists(String key) { + return kvStore.containsKey(key); + } + + @Override + public byte[] get(final String key) { + if (LOG.isInfoEnabled()) { + LOG.info("Get value of key {} start.", key); + } + + byte[] readData = kvStore.get(key); + + if (LOG.isInfoEnabled()) { + LOG.info("Get value of key {} success.", key); + } + + return readData; + } + + @Override + public void put(final String key, final byte[] value) { + if (LOG.isInfoEnabled()) { + LOG.info("Put value of key {} start.", key); + } + + kvStore.put(key, value); + + if (LOG.isInfoEnabled()) { + LOG.info("Put value of key {} success.", key); + } + } + + @Override + public void remove(final String key) { + if (LOG.isInfoEnabled()) { + LOG.info("Remove value of key {} start.", key); + } + + kvStore.remove(key); + + if (LOG.isInfoEnabled()) { + LOG.info("Remove value of key {} success.", key); + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java new file mode 100644 index 00000000..877f3c5b --- /dev/null +++ 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java @@ -0,0 +1,81 @@ +package io.ray.streaming.runtime.core.collector; + +import io.ray.api.BaseActorHandle; +import io.ray.api.PyActorHandle; +import io.ray.streaming.api.Language; +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.message.Record; +import io.ray.streaming.runtime.serialization.CrossLangSerializer; +import io.ray.streaming.runtime.serialization.JavaSerializer; +import io.ray.streaming.runtime.serialization.Serializer; +import io.ray.streaming.runtime.transfer.DataWriter; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import java.nio.ByteBuffer; +import java.util.Collection; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class OutputCollector implements Collector { + + private static final Logger LOGGER = LoggerFactory.getLogger(OutputCollector.class); + + private final DataWriter writer; + private final ChannelId[] outputQueues; + private final Collection targetActors; + private final Language[] targetLanguages; + private final Partition partition; + private final Serializer javaSerializer = new JavaSerializer(); + private final Serializer crossLangSerializer = new CrossLangSerializer(); + + public OutputCollector( + DataWriter writer, + Collection outputChannelIds, + Collection targetActors, + Partition partition) { + this.writer = writer; + this.outputQueues = outputChannelIds.stream().map(ChannelId::from).toArray(ChannelId[]::new); + this.targetActors = targetActors; + this.targetLanguages = + targetActors.stream() + .map(actor -> actor instanceof PyActorHandle ? 
Language.PYTHON : Language.JAVA) + .toArray(Language[]::new); + this.partition = partition; + LOGGER.debug( + "OutputCollector constructed, outputChannelIds:{}, partition:{}.", + outputChannelIds, + this.partition); + } + + @Override + public void collect(Record record) { + int[] partitions = this.partition.partition(record, outputQueues.length); + ByteBuffer javaBuffer = null; + ByteBuffer crossLangBuffer = null; + for (int partition : partitions) { + if (targetLanguages[partition] == Language.JAVA) { + // avoid repeated serialization + if (javaBuffer == null) { + byte[] bytes = javaSerializer.serialize(record); + javaBuffer = ByteBuffer.allocate(1 + bytes.length); + javaBuffer.put(Serializer.JAVA_TYPE_ID); + // TODO(chaokunyang) remove copy + javaBuffer.put(bytes); + javaBuffer.flip(); + } + writer.write(outputQueues[partition], javaBuffer.duplicate()); + } else { + // avoid repeated serialization + if (crossLangBuffer == null) { + byte[] bytes = crossLangSerializer.serialize(record); + crossLangBuffer = ByteBuffer.allocate(1 + bytes.length); + crossLangBuffer.put(Serializer.CROSS_LANG_TYPE_ID); + // TODO(chaokunyang) remove copy + crossLangBuffer.put(bytes); + crossLangBuffer.flip(); + } + writer.write(outputQueues[partition], crossLangBuffer.duplicate()); + } + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/command/BatchInfo.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/command/BatchInfo.java new file mode 100644 index 00000000..5291a586 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/command/BatchInfo.java @@ -0,0 +1,20 @@ +package io.ray.streaming.runtime.core.command; + +import java.io.Serializable; + +public class BatchInfo implements Serializable { + + private long batchId; + + public BatchInfo(long batchId) { + this.batchId = batchId; + } + + public long getBatchId() { + return batchId; + } + + public void 
setBatchId(long batchId) { + this.batchId = batchId; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/common/AbstractId.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/common/AbstractId.java new file mode 100644 index 00000000..d7c7e465 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/common/AbstractId.java @@ -0,0 +1,30 @@ +package io.ray.streaming.runtime.core.common; + +import com.google.common.base.MoreObjects; +import io.ray.streaming.runtime.core.resource.ContainerId; +import java.io.Serializable; +import java.util.UUID; + +/** Streaming system unique identity base class. For example, ${@link ContainerId } */ +public class AbstractId implements Serializable { + + private UUID id; + + public AbstractId() { + this.id = UUID.randomUUID(); + } + + @Override + public boolean equals(Object obj) { + return id.equals(((AbstractId) obj).getId()); + } + + public UUID getId() { + return id; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("id", id).toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionEdge.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionEdge.java new file mode 100644 index 00000000..3903f8ad --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionEdge.java @@ -0,0 +1,75 @@ +package io.ray.streaming.runtime.core.graph.executiongraph; + +import com.google.common.base.MoreObjects; +import io.ray.streaming.api.partition.Partition; +import java.io.Serializable; + +/** An edge that connects two execution vertices. */ +public class ExecutionEdge implements Serializable { + + /** The source(upstream) execution vertex. 
*/ + private final ExecutionVertex sourceExecutionVertex; + + /** The target(downstream) execution vertex. */ + private final ExecutionVertex targetExecutionVertex; + + /** The partition of current execution edge's execution job edge. */ + private final Partition partition; + + /** An unique id for execution edge. */ + private final String executionEdgeIndex; + + public ExecutionEdge( + ExecutionVertex sourceExecutionVertex, + ExecutionVertex targetExecutionVertex, + ExecutionJobEdge executionJobEdge) { + this.sourceExecutionVertex = sourceExecutionVertex; + this.targetExecutionVertex = targetExecutionVertex; + this.partition = executionJobEdge.getPartition(); + this.executionEdgeIndex = generateExecutionEdgeIndex(); + } + + private String generateExecutionEdgeIndex() { + return sourceExecutionVertex.getExecutionVertexId() + + "—" + + targetExecutionVertex.getExecutionVertexId(); + } + + public ExecutionVertex getSourceExecutionVertex() { + return sourceExecutionVertex; + } + + public ExecutionVertex getTargetExecutionVertex() { + return targetExecutionVertex; + } + + public String getTargetExecutionJobVertexName() { + return getTargetExecutionVertex().getExecutionJobVertexName(); + } + + public int getSourceVertexId() { + return sourceExecutionVertex.getExecutionVertexId(); + } + + public int getTargetVertexId() { + return targetExecutionVertex.getExecutionVertexId(); + } + + public Partition getPartition() { + return partition; + } + + public String getExecutionEdgeIndex() { + return executionEdgeIndex; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("source", sourceExecutionVertex) + .add("target", targetExecutionVertex) + .add("partition", partition) + .add("index", executionEdgeIndex) + .toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionGraph.java 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionGraph.java new file mode 100644 index 00000000..2852e0f9 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionGraph.java @@ -0,0 +1,332 @@ +package io.ray.streaming.runtime.core.graph.executiongraph; + +import com.google.common.collect.Sets; +import io.ray.api.BaseActorHandle; +import io.ray.api.id.ActorId; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Physical plan. */ +public class ExecutionGraph implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(ExecutionGraph.class); + + /** Name of the job. */ + private final String jobName; + + /** Configuration of the job. */ + private Map jobConfig; + + /** Data map for execution job vertex. key: job vertex id. value: execution job vertex. */ + private Map executionJobVertexMap; + + /** Data map for execution vertex. key: execution vertex id. value: execution vertex. */ + private Map executionVertexMap; + + /** Data map for execution vertex. key: actor id. value: execution vertex. */ + private Map actorIdExecutionVertexMap; + + /** key: channel ID value: actors in both sides of this channel */ + private Map> channelGroupedActors; + + /** The max parallelism of the whole graph. */ + private int maxParallelism; + + /** Build time. */ + private long buildTime; + + /** A monotonic increasing number, used for vertex's id(immutable). 
*/ + private AtomicInteger executionVertexIdGenerator = new AtomicInteger(0); + + public ExecutionGraph(String jobName) { + this.jobName = jobName; + this.buildTime = System.currentTimeMillis(); + } + + public String getJobName() { + return jobName; + } + + public List getExecutionJobVertexList() { + return new ArrayList<>(executionJobVertexMap.values()); + } + + public Map getExecutionJobVertexMap() { + return executionJobVertexMap; + } + + public void setExecutionJobVertexMap(Map executionJobVertexMap) { + this.executionJobVertexMap = executionJobVertexMap; + } + + /** + * generate relation mappings between actors, execution vertices and channels this method must be + * called after worker actor is set. + */ + public void generateActorMappings() { + LOG.info("Setup queue actors relation."); + + channelGroupedActors = new HashMap<>(); + actorIdExecutionVertexMap = new HashMap<>(); + + getAllExecutionVertices() + .forEach( + curVertex -> { + + // current + actorIdExecutionVertexMap.put(curVertex.getActorId(), curVertex); + + // input + List inputEdges = curVertex.getInputEdges(); + inputEdges.forEach( + inputEdge -> { + ExecutionVertex inputVertex = inputEdge.getSourceExecutionVertex(); + String channelId = curVertex.getChannelIdByPeerVertex(inputVertex); + addActorToChannelGroupedActors( + channelGroupedActors, channelId, inputVertex.getWorkerActor()); + }); + + // output + List outputEdges = curVertex.getOutputEdges(); + outputEdges.forEach( + outputEdge -> { + ExecutionVertex outputVertex = outputEdge.getTargetExecutionVertex(); + String channelId = curVertex.getChannelIdByPeerVertex(outputVertex); + addActorToChannelGroupedActors( + channelGroupedActors, channelId, outputVertex.getWorkerActor()); + }); + }); + + LOG.debug("Channel grouped actors is: {}.", channelGroupedActors); + } + + private void addActorToChannelGroupedActors( + Map> channelGroupedActors, + String queueName, + BaseActorHandle actor) { + + Set actorSet = + 
channelGroupedActors.computeIfAbsent(queueName, k -> new HashSet<>()); + actorSet.add(actor); + } + + public void setExecutionVertexMap(Map executionVertexMap) { + this.executionVertexMap = executionVertexMap; + } + + public Map getJobConfig() { + return jobConfig; + } + + public void setJobConfig(Map jobConfig) { + this.jobConfig = jobConfig; + } + + public int getMaxParallelism() { + return maxParallelism; + } + + public void setMaxParallelism(int maxParallelism) { + this.maxParallelism = maxParallelism; + } + + public long getBuildTime() { + return buildTime; + } + + public int generateExecutionVertexId() { + return executionVertexIdGenerator.getAndIncrement(); + } + + public AtomicInteger getExecutionVertexIdGenerator() { + return executionVertexIdGenerator; + } + + /** + * Get all execution vertices from current execution graph. + * + * @return all execution vertices. + */ + public List getAllExecutionVertices() { + return executionJobVertexMap.values().stream() + .map(ExecutionJobVertex::getExecutionVertices) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + } + + /** + * Get all execution vertices whose status is 'TO_ADD' from current execution graph. + * + * @return all added execution vertices. + */ + public List getAllAddedExecutionVertices() { + return executionJobVertexMap.values().stream() + .map(ExecutionJobVertex::getExecutionVertices) + .flatMap(Collection::stream) + .filter(ExecutionVertex::is2Add) + .collect(Collectors.toList()); + } + + /** + * Get specified execution vertex from current execution graph by execution vertex id. + * + * @param executionVertexId execution vertex id. + * @return the specified execution vertex. 
+ */ + public ExecutionVertex getExecutionVertexByExecutionVertexId(int executionVertexId) { + if (executionVertexMap.containsKey(executionVertexId)) { + return executionVertexMap.get(executionVertexId); + } + throw new RuntimeException("Vertex " + executionVertexId + " does not exist!"); + } + + /** + * Get specified execution vertex from current execution graph by actor id. + * + * @param actorId the actor id of execution vertex. + * @return the specified execution vertex. + */ + public ExecutionVertex getExecutionVertexByActorId(ActorId actorId) { + return actorIdExecutionVertexMap.get(actorId); + } + + /** + * Get specified actor by actor id. + * + * @param actorId the actor id of execution vertex. + * @return the specified actor handle. + */ + public Optional getActorById(ActorId actorId) { + return getAllActors().stream().filter(actor -> actor.getId().equals(actorId)).findFirst(); + } + + /** + * Get the peer actor in the other side of channelName of a given actor + * + * @param actor actor in this side + * @param channelName the channel name + * @return the peer actor in the other side + */ + public BaseActorHandle getPeerActor(BaseActorHandle actor, String channelName) { + Set set = getActorsByChannelId(channelName); + final BaseActorHandle[] res = new BaseActorHandle[1]; + set.forEach( + anActor -> { + if (!anActor.equals(actor)) { + res[0] = anActor; + } + }); + return res[0]; + } + + /** + * Get actors in both sides of a channelId + * + * @param channelId the channelId + * @return actors in both sides + */ + public Set getActorsByChannelId(String channelId) { + return channelGroupedActors.getOrDefault(channelId, Sets.newHashSet()); + } + + /** + * Get all actors by graph. + * + * @return actor list + */ + public List getAllActors() { + return getActorsFromJobVertices(getExecutionJobVertexList()); + } + + /** + * Get source actors by graph. 
+ * + * @return actor list + */ + public List getSourceActors() { + List executionJobVertices = + getExecutionJobVertexList().stream() + .filter(ExecutionJobVertex::isSourceVertex) + .collect(Collectors.toList()); + + return getActorsFromJobVertices(executionJobVertices); + } + + /** + * Get transformation and sink actors by graph. + * + * @return actor list + */ + public List getNonSourceActors() { + List executionJobVertices = + getExecutionJobVertexList().stream() + .filter( + executionJobVertex -> + executionJobVertex.isTransformationVertex() + || executionJobVertex.isSinkVertex()) + .collect(Collectors.toList()); + + return getActorsFromJobVertices(executionJobVertices); + } + + /** + * Get sink actors by graph. + * + * @return actor list + */ + public List getSinkActors() { + List executionJobVertices = + getExecutionJobVertexList().stream() + .filter(ExecutionJobVertex::isSinkVertex) + .collect(Collectors.toList()); + + return getActorsFromJobVertices(executionJobVertices); + } + + /** + * Get actors according to job vertices. 
+ * + * @param executionJobVertices specified job vertices + * @return actor list + */ + public List getActorsFromJobVertices( + List executionJobVertices) { + return executionJobVertices.stream() + .map(ExecutionJobVertex::getExecutionVertices) + .flatMap(Collection::stream) + .map(ExecutionVertex::getWorkerActor) + .collect(Collectors.toList()); + } + + public Set getActorName(Set actorIds) { + return getAllExecutionVertices().stream() + .filter(executionVertex -> actorIds.contains(executionVertex.getActorId())) + .map(ExecutionVertex::getActorName) + .collect(Collectors.toSet()); + } + + public String getActorName(ActorId actorId) { + Set set = Sets.newHashSet(); + set.add(actorId); + Set result = getActorName(set); + if (result.isEmpty()) { + return null; + } + return result.iterator().next(); + } + + public List getAllActorsId() { + return getAllActors().stream().map(BaseActorHandle::getId).collect(Collectors.toList()); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java new file mode 100644 index 00000000..72a6a86b --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java @@ -0,0 +1,60 @@ +package io.ray.streaming.runtime.core.graph.executiongraph; + +import com.google.common.base.MoreObjects; +import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.jobgraph.JobEdge; +import java.io.Serializable; + +/** An edge that connects two execution job vertices. */ +public class ExecutionJobEdge implements Serializable { + + /** The source(upstream) execution job vertex. */ + private final ExecutionJobVertex sourceExecutionJobVertex; + + /** The target(downstream) execution job vertex. 
*/ + private final ExecutionJobVertex targetExecutionJobVertex; + + /** The partition of the execution job edge. */ + private final Partition partition; + + /** An unique id for execution job edge. */ + private final String executionJobEdgeIndex; + + public ExecutionJobEdge( + ExecutionJobVertex sourceExecutionJobVertex, + ExecutionJobVertex targetExecutionJobVertex, + JobEdge jobEdge) { + this.sourceExecutionJobVertex = sourceExecutionJobVertex; + this.targetExecutionJobVertex = targetExecutionJobVertex; + this.partition = jobEdge.getPartition(); + this.executionJobEdgeIndex = generateExecutionJobEdgeIndex(); + } + + private String generateExecutionJobEdgeIndex() { + return sourceExecutionJobVertex.getExecutionJobVertexId() + + "—" + + targetExecutionJobVertex.getExecutionJobVertexId(); + } + + public ExecutionJobVertex getSourceExecutionJobVertex() { + return sourceExecutionJobVertex; + } + + public ExecutionJobVertex getTargetExecutionJobVertex() { + return targetExecutionJobVertex; + } + + public Partition getPartition() { + return partition; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("source", sourceExecutionJobVertex) + .add("target", targetExecutionJobVertex) + .add("partition", partition) + .add("index", executionJobEdgeIndex) + .toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java new file mode 100644 index 00000000..cf869c0c --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java @@ -0,0 +1,187 @@ +package io.ray.streaming.runtime.core.graph.executiongraph; + +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; +import io.ray.api.BaseActorHandle; +import 
io.ray.streaming.api.Language; +import io.ray.streaming.jobgraph.JobVertex; +import io.ray.streaming.jobgraph.VertexType; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.runtime.config.master.ResourceConfig; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.aeonbits.owner.ConfigFactory; + +/** + * Physical job vertex. + * + *

Execution job vertex is the physical form of {@link JobVertex} and every execution job vertex + * is corresponding to a group of {@link ExecutionVertex}. + */ +public class ExecutionJobVertex implements Serializable { + + /** Unique id. Use {@link JobVertex}'s id directly. */ + private final int executionJobVertexId; + + /** + * Use jobVertex id and operator(use {@link StreamOperator}'s name) as name. e.g. 1-SourceOperator + */ + private final String executionJobVertexName; + + private final StreamOperator streamOperator; + private final VertexType vertexType; + private final Language language; + private final Map jobConfig; + private final long buildTime; + + /** Parallelism of current execution job vertex(operator). */ + private int parallelism; + + /** Sub execution vertices of current execution job vertex(operator). */ + private List executionVertices; + + /** Input and output edges of current execution job vertex. */ + private List inputEdges = new ArrayList<>(); + + private List outputEdges = new ArrayList<>(); + + public ExecutionJobVertex( + JobVertex jobVertex, + Map jobConfig, + AtomicInteger idGenerator, + long buildTime) { + this.executionJobVertexId = jobVertex.getVertexId(); + this.executionJobVertexName = + generateExecutionJobVertexName( + executionJobVertexId, jobVertex.getStreamOperator().getName()); + this.streamOperator = jobVertex.getStreamOperator(); + this.vertexType = jobVertex.getVertexType(); + this.language = jobVertex.getLanguage(); + this.jobConfig = jobConfig; + this.buildTime = buildTime; + this.parallelism = jobVertex.getParallelism(); + this.executionVertices = createExecutionVertices(idGenerator); + } + + private List createExecutionVertices(AtomicInteger idGenerator) { + List executionVertices = new ArrayList<>(); + ResourceConfig resourceConfig = ConfigFactory.create(ResourceConfig.class, jobConfig); + + for (int subIndex = 0; subIndex < parallelism; subIndex++) { + executionVertices.add( + new 
ExecutionVertex(idGenerator.getAndIncrement(), subIndex, this, resourceConfig)); + } + return executionVertices; + } + + private String generateExecutionJobVertexName(int jobVertexId, String streamOperatorName) { + return jobVertexId + "-" + streamOperatorName; + } + + public Map getExecutionVertexWorkers() { + Map executionVertexWorkersMap = new HashMap<>(); + + Preconditions.checkArgument( + executionVertices != null && !executionVertices.isEmpty(), "Empty execution vertex."); + executionVertices.stream() + .forEach( + vertex -> { + Preconditions.checkArgument( + vertex.getWorkerActor() != null, "Empty execution vertex worker actor."); + executionVertexWorkersMap.put(vertex.getExecutionVertexId(), vertex.getWorkerActor()); + }); + + return executionVertexWorkersMap; + } + + public int getExecutionJobVertexId() { + return executionJobVertexId; + } + + public String getExecutionJobVertexName() { + return executionJobVertexName; + } + + /** + * e.g. 1-SourceOperator + * + * @return operator name with index + */ + public String getExecutionJobVertexNameWithIndex() { + return executionJobVertexId + "-" + executionJobVertexName; + } + + public int getParallelism() { + return parallelism; + } + + public List getExecutionVertices() { + return executionVertices; + } + + public void setExecutionVertices(List executionVertex) { + this.executionVertices = executionVertex; + } + + public List getOutputEdges() { + return outputEdges; + } + + public void setOutputEdges(List outputEdges) { + this.outputEdges = outputEdges; + } + + public List getInputEdges() { + return inputEdges; + } + + public void setInputEdges(List inputEdges) { + this.inputEdges = inputEdges; + } + + public StreamOperator getStreamOperator() { + return streamOperator; + } + + public VertexType getVertexType() { + return vertexType; + } + + public Language getLanguage() { + return language; + } + + public Map getJobConfig() { + return jobConfig; + } + + public long getBuildTime() { + return buildTime; + } + 
+ public boolean isSourceVertex() { + return getVertexType() == VertexType.SOURCE; + } + + public boolean isTransformationVertex() { + return getVertexType() == VertexType.TRANSFORMATION; + } + + public boolean isSinkVertex() { + return getVertexType() == VertexType.SINK; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("executionJobVertexId", executionJobVertexId) + .add("executionJobVertexName", executionJobVertexName) + .add("vertexType", vertexType) + .add("parallelism", parallelism) + .toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java new file mode 100644 index 00000000..f53e9e5d --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java @@ -0,0 +1,338 @@ +package io.ray.streaming.runtime.core.graph.executiongraph; + +import com.google.common.base.MoreObjects; +import io.ray.api.BaseActorHandle; +import io.ray.api.id.ActorId; +import io.ray.streaming.api.Language; +import io.ray.streaming.jobgraph.VertexType; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.runtime.config.master.ResourceConfig; +import io.ray.streaming.runtime.core.resource.ContainerId; +import io.ray.streaming.runtime.core.resource.ResourceType; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +/** Physical vertex, correspond to {@link ExecutionJobVertex}. */ +public class ExecutionVertex implements Serializable { + + /** Unique id for execution vertex. 
*/ + private final int executionVertexId; + + /** Immutable field inherited from {@link ExecutionJobVertex}. */ + private final int executionJobVertexId; + + private final String executionJobVertexName; + private final StreamOperator streamOperator; + private final VertexType vertexType; + private final Language language; + private final long buildTime; + + /** Resource used by ExecutionVertex. */ + private final Map resource; + + /** Parallelism of current vertex's operator. */ + private int parallelism; + + /** + * Ordered sub index for execution vertex in a execution job vertex. Might be changed in dynamic + * scheduling. + */ + private int executionVertexIndex; + + private ExecutionVertexState state = ExecutionVertexState.TO_ADD; + + /** The id of the container which this vertex's worker actor belongs to. */ + private ContainerId containerId; + + private String pid; + + /** Worker actor handle. */ + private BaseActorHandle workerActor; + + /** Op config + job config. */ + private Map workerConfig; + + private List inputEdges = new ArrayList<>(); + private List outputEdges = new ArrayList<>(); + + private transient List outputChannelIdList; + private transient List inputChannelIdList; + + private transient List outputActorList; + private transient List inputActorList; + private Map exeVertexChannelMap; + + public ExecutionVertex( + int globalIndex, + int index, + ExecutionJobVertex executionJobVertex, + ResourceConfig resourceConfig) { + this.executionVertexId = globalIndex; + this.executionJobVertexId = executionJobVertex.getExecutionJobVertexId(); + this.executionJobVertexName = executionJobVertex.getExecutionJobVertexName(); + this.streamOperator = executionJobVertex.getStreamOperator(); + this.vertexType = executionJobVertex.getVertexType(); + this.language = executionJobVertex.getLanguage(); + this.buildTime = executionJobVertex.getBuildTime(); + this.parallelism = executionJobVertex.getParallelism(); + this.executionVertexIndex = index; + this.resource = 
generateResources(resourceConfig); + this.workerConfig = genWorkerConfig(executionJobVertex.getJobConfig()); + } + + private Map genWorkerConfig(Map jobConfig) { + return new HashMap<>(jobConfig); + } + + public int getExecutionVertexId() { + return executionVertexId; + } + + /** + * Unique name generated by execution job vertex name and index of current execution vertex. e.g. + * 1-SourceOperator-3 (vertex index is 3) + */ + public String getExecutionVertexName() { + return executionJobVertexName + "-" + executionVertexIndex; + } + + public int getExecutionJobVertexId() { + return executionJobVertexId; + } + + public String getExecutionJobVertexName() { + return executionJobVertexName; + } + + public StreamOperator getStreamOperator() { + return streamOperator; + } + + public VertexType getVertexType() { + return vertexType; + } + + public Language getLanguage() { + return language; + } + + public int getParallelism() { + return parallelism; + } + + public int getExecutionVertexIndex() { + return executionVertexIndex; + } + + public ExecutionVertexState getState() { + return state; + } + + public void setState(ExecutionVertexState state) { + this.state = state; + } + + public boolean is2Add() { + return state == ExecutionVertexState.TO_ADD; + } + + public boolean isRunning() { + return state == ExecutionVertexState.RUNNING; + } + + public boolean is2Delete() { + return state == ExecutionVertexState.TO_DEL; + } + + public BaseActorHandle getWorkerActor() { + return workerActor; + } + + public void setWorkerActor(BaseActorHandle workerActor) { + this.workerActor = workerActor; + } + + public ActorId getWorkerActorId() { + return workerActor.getId(); + } + + public List getInputEdges() { + return inputEdges; + } + + public void setInputEdges(List inputEdges) { + this.inputEdges = inputEdges; + } + + public List getOutputEdges() { + return outputEdges; + } + + public void setOutputEdges(List outputEdges) { + this.outputEdges = outputEdges; + } + + public List 
getInputVertices() { + return inputEdges.stream() + .map(ExecutionEdge::getSourceExecutionVertex) + .collect(Collectors.toList()); + } + + public List getOutputVertices() { + return outputEdges.stream() + .map(ExecutionEdge::getTargetExecutionVertex) + .collect(Collectors.toList()); + } + + public ActorId getActorId() { + return null == workerActor ? null : workerActor.getId(); + } + + public String getActorName() { + return String.valueOf(executionVertexId); + } + + public Map getResource() { + return resource; + } + + public Map getWorkerConfig() { + return workerConfig; + } + + public long getBuildTime() { + return buildTime; + } + + public ContainerId getContainerId() { + return containerId; + } + + public void setContainerId(ContainerId containerId) { + this.containerId = containerId; + } + + public String getPid() { + return pid; + } + + public void setPid(String pid) { + this.pid = pid; + } + + public void setContainerIfNotExist(ContainerId containerId) { + if (null == this.containerId) { + this.containerId = containerId; + } + } + + /*---------channel-actor relations---------*/ + public List getOutputChannelIdList() { + if (outputChannelIdList == null) { + generateActorChannelInfo(); + } + return outputChannelIdList; + } + + public List getOutputActorList() { + if (outputActorList == null) { + generateActorChannelInfo(); + } + return outputActorList; + } + + public List getInputChannelIdList() { + if (inputChannelIdList == null) { + generateActorChannelInfo(); + } + return inputChannelIdList; + } + + public List getInputActorList() { + if (inputActorList == null) { + generateActorChannelInfo(); + } + return inputActorList; + } + + public String getChannelIdByPeerVertex(ExecutionVertex peerVertex) { + if (exeVertexChannelMap == null) { + generateActorChannelInfo(); + } + return exeVertexChannelMap.get(peerVertex.getExecutionVertexId()); + } + + private void generateActorChannelInfo() { + inputChannelIdList = new ArrayList<>(); + inputActorList = new 
ArrayList<>(); + outputChannelIdList = new ArrayList<>(); + outputActorList = new ArrayList<>(); + exeVertexChannelMap = new HashMap<>(); + + List inputEdges = getInputEdges(); + for (ExecutionEdge edge : inputEdges) { + String channelId = + ChannelId.genIdStr( + edge.getSourceExecutionVertex().getExecutionVertexId(), + getExecutionVertexId(), + getBuildTime()); + inputChannelIdList.add(channelId); + inputActorList.add(edge.getSourceExecutionVertex().getWorkerActor()); + exeVertexChannelMap.put(edge.getSourceExecutionVertex().getExecutionVertexId(), channelId); + } + + List outputEdges = getOutputEdges(); + for (ExecutionEdge edge : outputEdges) { + String channelId = + ChannelId.genIdStr( + getExecutionVertexId(), + edge.getTargetExecutionVertex().getExecutionVertexId(), + getBuildTime()); + outputChannelIdList.add(channelId); + outputActorList.add(edge.getTargetExecutionVertex().getWorkerActor()); + exeVertexChannelMap.put(edge.getTargetExecutionVertex().getExecutionVertexId(), channelId); + } + } + + private Map generateResources(ResourceConfig resourceConfig) { + Map resourceMap = new HashMap<>(); + if (resourceConfig.isTaskCpuResourceLimit()) { + resourceMap.put(ResourceType.CPU.name(), resourceConfig.taskCpuResource()); + } + if (resourceConfig.isTaskMemResourceLimit()) { + resourceMap.put(ResourceType.MEM.name(), resourceConfig.taskMemResource()); + } + return resourceMap; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof ExecutionVertex) { + return this.executionVertexId == ((ExecutionVertex) obj).getExecutionVertexId(); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(executionVertexId, outputEdges); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("id", executionVertexId) + .add("name", getExecutionVertexName()) + .add("resources", resource) + .add("state", state) + .add("containerId", containerId) + .add("workerActor", workerActor) + 
.toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertexState.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertexState.java new file mode 100644 index 00000000..5227c988 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertexState.java @@ -0,0 +1,27 @@ +package io.ray.streaming.runtime.core.graph.executiongraph; + +import java.io.Serializable; + +/** Vertex state. */ +public enum ExecutionVertexState implements Serializable { + + /** Vertex(Worker) to be added. */ + TO_ADD(1, "TO_ADD"), + + /** Vertex(Worker) to be deleted. */ + TO_DEL(2, "TO_DEL"), + + /** Vertex(Worker) is running. */ + RUNNING(3, "RUNNING"), + + /** Unknown status, */ + UNKNOWN(-1, "UNKNOWN"); + + public final int code; + public final String msg; + + ExecutionVertexState(int code, String msg) { + this.code = code; + this.msg = msg; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/OneInputProcessor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/OneInputProcessor.java new file mode 100644 index 00000000..faad2fc9 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/OneInputProcessor.java @@ -0,0 +1,29 @@ +package io.ray.streaming.runtime.core.processor; + +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.OneInputOperator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class OneInputProcessor extends StreamProcessor, OneInputOperator> { + + private static final Logger LOGGER = LoggerFactory.getLogger(OneInputProcessor.class); + + public OneInputProcessor(OneInputOperator operator) { + super(operator); + } + + @Override + public void process(Record record) { + try 
{ + this.operator.processElement(record); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() { + this.operator.close(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java new file mode 100644 index 00000000..971e7324 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java @@ -0,0 +1,32 @@ +package io.ray.streaming.runtime.core.processor; + +import io.ray.streaming.operator.OneInputOperator; +import io.ray.streaming.operator.OperatorType; +import io.ray.streaming.operator.SourceOperator; +import io.ray.streaming.operator.StreamOperator; +import io.ray.streaming.operator.TwoInputOperator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ProcessBuilder { + + private static final Logger LOGGER = LoggerFactory.getLogger(ProcessBuilder.class); + + public static StreamProcessor buildProcessor(StreamOperator streamOperator) { + OperatorType type = streamOperator.getOpType(); + LOGGER.info( + "Building StreamProcessor, operator type = {}, operator = {}.", + type, + streamOperator.getClass().getSimpleName()); + switch (type) { + case SOURCE: + return new SourceProcessor<>((SourceOperator) streamOperator); + case ONE_INPUT: + return new OneInputProcessor<>((OneInputOperator) streamOperator); + case TWO_INPUT: + return new TwoInputProcessor((TwoInputOperator) streamOperator); + default: + throw new RuntimeException("current operator type is not support"); + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java new file mode 100644 index 00000000..c323b968 --- /dev/null +++ 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java @@ -0,0 +1,22 @@ +package io.ray.streaming.runtime.core.processor; + +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.function.Function; +import java.io.Serializable; +import java.util.List; + +public interface Processor extends Serializable { + + void open(List collectors, RuntimeContext runtimeContext); + + void process(T t); + + /** See {@link Function#saveCheckpoint()}. */ + Serializable saveCheckpoint(); + + /** See {@link Function#loadCheckpoint(Serializable)}. */ + void loadCheckpoint(Serializable checkpointObject); + + void close(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java new file mode 100644 index 00000000..802ef122 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java @@ -0,0 +1,28 @@ +package io.ray.streaming.runtime.core.processor; + +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.SourceOperator; + +/** + * The processor for the stream sources, containing a SourceOperator. + * + * @param The type of source data. 
+ */ +public class SourceProcessor extends StreamProcessor> { + + public SourceProcessor(SourceOperator operator) { + super(operator); + } + + @Override + public void process(Record record) { + throw new UnsupportedOperationException("SourceProcessor should not process record"); + } + + public void fetch() { + operator.fetch(); + } + + @Override + public void close() {} +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/StreamProcessor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/StreamProcessor.java new file mode 100644 index 00000000..f27bd747 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/StreamProcessor.java @@ -0,0 +1,53 @@ +package io.ray.streaming.runtime.core.processor; + +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.operator.Operator; +import java.io.Serializable; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * StreamingProcessor is a process unit for a operator. + * + * @param The type of process data. + * @param

Type of the specific operator class. + */ +public abstract class StreamProcessor implements Processor { + + private static final Logger LOGGER = LoggerFactory.getLogger(StreamProcessor.class); + + protected List collectors; + protected RuntimeContext runtimeContext; + protected P operator; + + public StreamProcessor(P operator) { + this.operator = operator; + } + + @Override + public void open(List collectors, RuntimeContext runtimeContext) { + this.collectors = collectors; + this.runtimeContext = runtimeContext; + if (operator != null) { + this.operator.open(collectors, runtimeContext); + } + LOGGER.info("opened {}", this); + } + + @Override + public Serializable saveCheckpoint() { + return operator.saveCheckpoint(); + } + + @Override + public void loadCheckpoint(Serializable checkpointObject) { + operator.loadCheckpoint(checkpointObject); + } + + @Override + public String toString() { + return this.getClass().getSimpleName(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/TwoInputProcessor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/TwoInputProcessor.java new file mode 100644 index 00000000..db3222ee --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/TwoInputProcessor.java @@ -0,0 +1,52 @@ +package io.ray.streaming.runtime.core.processor; + +import io.ray.streaming.message.Record; +import io.ray.streaming.operator.TwoInputOperator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class TwoInputProcessor extends StreamProcessor> { + + private static final Logger LOGGER = LoggerFactory.getLogger(TwoInputProcessor.class); + + private String leftStream; + private String rightStream; + + public TwoInputProcessor(TwoInputOperator operator) { + super(operator); + } + + @Override + public void process(Record record) { + try { + if (record.getStream().equals(leftStream)) { + 
this.operator.processElement(record, null); + } else { + this.operator.processElement(null, record); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() { + this.operator.close(); + } + + public String getLeftStream() { + return leftStream; + } + + public void setLeftStream(String leftStream) { + this.leftStream = leftStream; + } + + public String getRightStream() { + return rightStream; + } + + public void setRightStream(String rightStream) { + this.rightStream = rightStream; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/Container.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/Container.java new file mode 100644 index 00000000..dc3bd6ca --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/Container.java @@ -0,0 +1,185 @@ +package io.ray.streaming.runtime.core.resource; + +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; +import io.ray.api.id.UniqueId; +import io.ray.api.runtimecontext.NodeInfo; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Container is physical resource abstraction. It identifies the available resources(cpu,mem,etc.) + * and allocated actors. 
+ */ +public class Container implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(Container.class); + + /** container id */ + private ContainerId id; + + /** Container address */ + private String address; + + /** Container hostname */ + private String hostname; + + /** Container unique id fetched from raylet */ + private UniqueId nodeId; + + /** Container available resources */ + private Map availableResources = new HashMap<>(); + + /** List of {@link ExecutionVertex} ids belong to the container. */ + private List executionVertexIds = new ArrayList<>(); + + /** Capacity is max actor number could be allocated in the container */ + private int capacity = 0; + + public Container() {} + + public Container( + String address, UniqueId nodeId, String hostname, Map availableResources) { + + this.id = new ContainerId(); + this.address = address; + this.hostname = hostname; + this.nodeId = nodeId; + this.availableResources = availableResources; + } + + public static Container from(NodeInfo nodeInfo) { + return new Container( + nodeInfo.nodeAddress, nodeInfo.nodeId, nodeInfo.nodeHostname, nodeInfo.resources); + } + + public ContainerId getId() { + return id; + } + + public void setId(ContainerId id) { + this.id = id; + } + + public String getName() { + return id.toString(); + } + + public String getAddress() { + return address; + } + + public UniqueId getNodeId() { + return nodeId; + } + + public String getHostname() { + return hostname; + } + + public Map getAvailableResources() { + return availableResources; + } + + public int getCapacity() { + return capacity; + } + + public void updateCapacity(int capacity) { + LOG.info("Update container capacity, old value: {}, new value: {}.", this.capacity, capacity); + this.capacity = capacity; + } + + public int getRemainingCapacity() { + return capacity - getAllocatedActorNum(); + } + + public int getAllocatedActorNum() { + return executionVertexIds.size(); + } + + public boolean isFull() { + 
return getAllocatedActorNum() >= capacity; + } + + public boolean isEmpty() { + return getAllocatedActorNum() == 0; + } + + public void allocateActor(ExecutionVertex vertex) { + LOG.info("Allocating vertex [{}] in container [{}].", vertex, this); + + executionVertexIds.add(vertex.getExecutionVertexId()); + vertex.setContainerIfNotExist(this.getId()); + decreaseResource(vertex.getResource()); + } + + public void releaseActor(ExecutionVertex vertex) { + LOG.info("Release actor, vertex: {}, container: {}.", vertex, vertex.getContainerId()); + if (executionVertexIds.contains(vertex.getExecutionVertexId())) { + executionVertexIds.removeIf(id -> id == vertex.getExecutionVertexId()); + reclaimResource(vertex.getResource()); + } else { + throw new RuntimeException( + String.format( + "Current container [%s] not found vertex [%s].", + this, vertex.getExecutionJobVertexName())); + } + } + + public List getExecutionVertexIds() { + return executionVertexIds; + } + + private void decreaseResource(Map allocatedResource) { + allocatedResource.forEach( + (k, v) -> { + Preconditions.checkArgument( + this.availableResources.get(k) >= v, + String.format( + "Available resource %s not >= decreased resource %s", + this.availableResources.get(k), v)); + Double newValue = this.availableResources.get(k) - v; + LOG.info( + "Decrease container {} resource [{}], from {} to {}.", + this.address, + k, + this.availableResources.get(k), + newValue); + this.availableResources.put(k, newValue); + }); + } + + private void reclaimResource(Map allocatedResource) { + allocatedResource.forEach( + (k, v) -> { + Double newValue = this.availableResources.get(k) + v; + LOG.info( + "Reclaim container {} resource [{}], from {} to {}.", + this.address, + k, + this.availableResources.get(k), + newValue); + this.availableResources.put(k, newValue); + }); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("id", id) + .add("address", address) + .add("hostname", 
hostname) + .add("nodeId", nodeId) + .add("availableResources", availableResources) + .add("executionVertexIds", executionVertexIds) + .add("capacity", capacity) + .toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/ContainerId.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/ContainerId.java new file mode 100644 index 00000000..a0b08ad2 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/ContainerId.java @@ -0,0 +1,6 @@ +package io.ray.streaming.runtime.core.resource; + +import io.ray.streaming.runtime.core.common.AbstractId; + +/** Container unique identifier. */ +public class ContainerId extends AbstractId {} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/ResourceType.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/ResourceType.java new file mode 100644 index 00000000..e6b5d0ef --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/ResourceType.java @@ -0,0 +1,24 @@ +package io.ray.streaming.runtime.core.resource; + +/** Key for different type of resources. */ +public enum ResourceType { + + /** Cpu resource key. */ + CPU("CPU"), + + /** Gpu resource key. */ + GPU("GPU"), + + /** Memory resource key. 
*/ + MEM("MEM"); + + private String value; + + ResourceType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/Resources.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/Resources.java new file mode 100644 index 00000000..9b07d131 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/resource/Resources.java @@ -0,0 +1,61 @@ +package io.ray.streaming.runtime.core.resource; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.ray.api.id.UniqueId; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Resource description of ResourceManager. */ +public class Resources implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(Resources.class); + + /** Available containers registered to ResourceManager. */ + private List registerContainers = new ArrayList<>(); + + public Resources() {} + + /** + * Get registered containers, the container list is read-only. + * + * @return container list. 
+ */ + public ImmutableList getRegisteredContainers() { + return ImmutableList.copyOf(registerContainers); + } + + public void registerContainer(Container container) { + LOG.info("Add container {} to registry list.", container); + this.registerContainers.add(container); + } + + public void unRegisterContainer(List deletedUniqueIds) { + Iterator iter = registerContainers.iterator(); + while (iter.hasNext()) { + Container deletedContainer = iter.next(); + if (deletedUniqueIds.contains(deletedContainer.getNodeId())) { + LOG.info("Remove container {} from registry list.", deletedContainer); + iter.remove(); + } + } + } + + public ImmutableMap getRegisteredContainerMap() { + return ImmutableMap.copyOf( + registerContainers.stream() + .collect(java.util.stream.Collectors.toMap(Container::getNodeId, c -> c))); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("registerContainers", registerContainers) + .toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java new file mode 100644 index 00000000..fd672978 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java @@ -0,0 +1,269 @@ +package io.ray.streaming.runtime.master; + +import com.google.common.base.Preconditions; +import com.google.protobuf.InvalidProtocolBufferException; +import io.ray.api.ActorHandle; +import io.ray.api.BaseActorHandle; +import io.ray.api.Ray; +import io.ray.api.id.ActorId; +import io.ray.streaming.jobgraph.JobGraph; +import io.ray.streaming.runtime.config.StreamingConfig; +import io.ray.streaming.runtime.config.StreamingMasterConfig; +import io.ray.streaming.runtime.context.ContextBackend; +import io.ray.streaming.runtime.context.ContextBackendFactory; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import 
io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.master.coordinator.CheckpointCoordinator; +import io.ray.streaming.runtime.master.coordinator.FailoverCoordinator; +import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport; +import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest; +import io.ray.streaming.runtime.master.graphmanager.GraphManager; +import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl; +import io.ray.streaming.runtime.master.resourcemanager.ResourceManager; +import io.ray.streaming.runtime.master.resourcemanager.ResourceManagerImpl; +import io.ray.streaming.runtime.master.scheduler.JobSchedulerImpl; +import io.ray.streaming.runtime.util.CheckpointStateUtil; +import io.ray.streaming.runtime.util.ResourceUtil; +import io.ray.streaming.runtime.util.Serializer; +import io.ray.streaming.runtime.worker.JobWorker; +import java.util.Map; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * JobMaster is the core controller in streaming job as a ray actor. It is responsible for all the + * controls facing the {@link JobWorker}. 
+ */ +public class JobMaster { + + private static final Logger LOG = LoggerFactory.getLogger(JobMaster.class); + + private JobMasterRuntimeContext runtimeContext; + private ResourceManager resourceManager; + private JobSchedulerImpl scheduler; + private GraphManager graphManager; + private StreamingMasterConfig conf; + + private ContextBackend contextBackend; + + private ActorHandle jobMasterActor; + + // coordinators + private CheckpointCoordinator checkpointCoordinator; + private FailoverCoordinator failoverCoordinator; + + public JobMaster(Map confMap) { + LOG.info("Creating job master with conf: {}.", confMap); + + StreamingConfig streamingConfig = new StreamingConfig(confMap); + this.conf = streamingConfig.masterConfig; + this.contextBackend = ContextBackendFactory.getContextBackend(this.conf); + + // init runtime context + runtimeContext = new JobMasterRuntimeContext(streamingConfig); + + // load checkpoint if is recover + if (!Ray.getRuntimeContext().isSingleProcess() + && Ray.getRuntimeContext().wasCurrentActorRestarted()) { + loadMasterCheckpoint(); + } + + LOG.info("Finished creating job master."); + } + + public static String getJobMasterRuntimeContextKey(StreamingMasterConfig conf) { + return conf.checkpointConfig.jobMasterContextCpPrefixKey() + conf.commonConfig.jobName(); + } + + private void loadMasterCheckpoint() { + LOG.info("Start to load JobMaster's checkpoint."); + // recover runtime context + byte[] bytes = + CheckpointStateUtil.get(contextBackend, getJobMasterRuntimeContextKey(getConf())); + if (bytes == null) { + LOG.warn("JobMaster got empty checkpoint from state backend. Skip loading checkpoint."); + // cp 0 was automatically saved when job started, see StreamTask. 
+ runtimeContext.checkpointIds.add(0L); + return; + } + + this.runtimeContext = Serializer.decode(bytes); + + // FO case, triggered by ray, we need to register context when loading checkpoint + LOG.info("JobMaster recover runtime context[{}] from state backend.", runtimeContext); + init(true); + } + + /** + * Init JobMaster. To initiate or recover other components(like metrics and extra coordinators). + * + * @return init result + */ + public Boolean init(boolean isRecover) { + LOG.info("Initializing job master, isRecover={}.", isRecover); + + if (this.runtimeContext.getExecutionGraph() == null) { + LOG.error("Init job master failed. Job graphs is null."); + return false; + } + + ExecutionGraph executionGraph = graphManager.getExecutionGraph(); + Preconditions.checkArgument(executionGraph != null, "no execution graph"); + + // init coordinators + checkpointCoordinator = new CheckpointCoordinator(this); + checkpointCoordinator.start(); + failoverCoordinator = new FailoverCoordinator(this, isRecover); + failoverCoordinator.start(); + + saveContext(); + + LOG.info("Finished initializing job master."); + return true; + } + + /** + * Submit job to run: + * + *

    + *
  1. Using GraphManager to build the physical plan according to the logical plan. + *
  2. Using ResourceManager to manage and allocate the resources. + *
  3. Using JobScheduler to schedule the job to run. + *
+ * + * @param jobMasterActor JobMaster actor + * @param jobGraph logical plan + * @return submit result + */ + public boolean submitJob(ActorHandle jobMasterActor, JobGraph jobGraph) { + LOG.info("Begin submitting job using logical plan: {}.", jobGraph); + + this.jobMasterActor = jobMasterActor; + + // init manager + graphManager = new GraphManagerImpl(runtimeContext); + resourceManager = new ResourceManagerImpl(runtimeContext); + + // build and set graph into runtime context + ExecutionGraph executionGraph = graphManager.buildExecutionGraph(jobGraph); + runtimeContext.setJobGraph(jobGraph); + runtimeContext.setExecutionGraph(executionGraph); + + // init scheduler + try { + scheduler = new JobSchedulerImpl(this); + scheduler.scheduleJob(graphManager.getExecutionGraph()); + } catch (Exception e) { + e.printStackTrace(); + LOG.error("Failed to submit job {}.", e, e); + return false; + } + return true; + } + + public synchronized void saveContext() { + if (runtimeContext != null && getConf() != null) { + LOG.debug("Save JobMaster context."); + + byte[] contextBytes = Serializer.encode(runtimeContext); + CheckpointStateUtil.put( + contextBackend, getJobMasterRuntimeContextKey(getConf()), contextBytes); + } + } + + public byte[] reportJobWorkerCommit(byte[] reportBytes) { + Boolean ret = false; + RemoteCall.BaseWorkerCmd reportPb; + try { + reportPb = RemoteCall.BaseWorkerCmd.parseFrom(reportBytes); + ActorId actorId = ActorId.fromBytes(reportPb.getActorId().toByteArray()); + long remoteCallCost = System.currentTimeMillis() - reportPb.getTimestamp(); + LOG.info( + "Vertex {}, request job worker commit cost {}ms, actorId={}.", + getExecutionVertex(actorId), + remoteCallCost, + actorId); + RemoteCall.WorkerCommitReport commit = + reportPb.getDetail().unpack(RemoteCall.WorkerCommitReport.class); + WorkerCommitReport report = new WorkerCommitReport(actorId, commit.getCommitCheckpointId()); + ret = checkpointCoordinator.reportJobWorkerCommit(report); + } catch 
(InvalidProtocolBufferException e) { + LOG.error("Parse job worker commit has exception.", e); + } + return RemoteCall.BoolResult.newBuilder().setBoolRes(ret).build().toByteArray(); + } + + public byte[] requestJobWorkerRollback(byte[] requestBytes) { + Boolean ret = false; + RemoteCall.BaseWorkerCmd requestPb; + try { + requestPb = RemoteCall.BaseWorkerCmd.parseFrom(requestBytes); + ActorId actorId = ActorId.fromBytes(requestPb.getActorId().toByteArray()); + long remoteCallCost = System.currentTimeMillis() - requestPb.getTimestamp(); + ExecutionGraph executionGraph = graphManager.getExecutionGraph(); + Optional rayActor = executionGraph.getActorById(actorId); + if (!rayActor.isPresent()) { + LOG.warn("Skip this invalid rollback, actor id {} is not found.", actorId); + return RemoteCall.BoolResult.newBuilder().setBoolRes(false).build().toByteArray(); + } + ExecutionVertex exeVertex = getExecutionVertex(actorId); + LOG.info( + "Vertex {}, request job worker rollback cost {}ms, actorId={}.", + exeVertex, + remoteCallCost, + actorId); + RemoteCall.WorkerRollbackRequest rollbackPb = + RemoteCall.WorkerRollbackRequest.parseFrom(requestPb.getDetail().getValue()); + exeVertex.setPid(rollbackPb.getWorkerPid()); + // To find old container where slot is located in. 
+ String hostname = ""; + Optional container = + ResourceUtil.getContainerById( + resourceManager.getRegisteredContainers(), exeVertex.getContainerId()); + if (container.isPresent()) { + hostname = container.get().getHostname(); + } + WorkerRollbackRequest request = + new WorkerRollbackRequest( + actorId, rollbackPb.getExceptionMsg(), hostname, exeVertex.getPid()); + + ret = failoverCoordinator.requestJobWorkerRollback(request); + LOG.info( + "Vertex {} request rollback, exception msg : {}.", + exeVertex, + rollbackPb.getExceptionMsg()); + + } catch (Throwable e) { + LOG.error("Parse job worker rollback has exception.", e); + } + return RemoteCall.BoolResult.newBuilder().setBoolRes(ret).build().toByteArray(); + } + + private ExecutionVertex getExecutionVertex(ActorId id) { + return graphManager.getExecutionGraph().getExecutionVertexByActorId(id); + } + + public ActorHandle getJobMasterActor() { + return jobMasterActor; + } + + public JobMasterRuntimeContext getRuntimeContext() { + return runtimeContext; + } + + public ResourceManager getResourceManager() { + return resourceManager; + } + + public GraphManager getGraphManager() { + return graphManager; + } + + public StreamingMasterConfig getConf() { + return conf; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobRuntimeContext.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobRuntimeContext.java new file mode 100644 index 00000000..a271ff2d --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobRuntimeContext.java @@ -0,0 +1,56 @@ +package io.ray.streaming.runtime.master; + +import com.google.common.base.MoreObjects; +import io.ray.streaming.jobgraph.JobGraph; +import io.ray.streaming.runtime.config.StreamingConfig; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import java.io.Serializable; + +/** + * Runtime context for job master. + * + *

Including: graph, resource, checkpoint info, etc. + */ +public class JobRuntimeContext implements Serializable { + + private StreamingConfig conf; + private JobGraph jobGraph; + private volatile ExecutionGraph executionGraph; + + public JobRuntimeContext(StreamingConfig conf) { + this.conf = conf; + } + + public String getJobName() { + return conf.masterConfig.commonConfig.jobName(); + } + + public StreamingConfig getConf() { + return conf; + } + + public JobGraph getJobGraph() { + return jobGraph; + } + + public void setJobGraph(JobGraph jobGraph) { + this.jobGraph = jobGraph; + } + + public ExecutionGraph getExecutionGraph() { + return executionGraph; + } + + public void setExecutionGraph(ExecutionGraph executionGraph) { + this.executionGraph = executionGraph; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("jobGraph", jobGraph) + .add("executionGraph", executionGraph) + .add("conf", conf.getMap()) + .toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java new file mode 100644 index 00000000..1765519e --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java @@ -0,0 +1,80 @@ +package io.ray.streaming.runtime.master.context; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.Sets; +import io.ray.streaming.jobgraph.JobGraph; +import io.ray.streaming.runtime.config.StreamingConfig; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ArrayBlockingQueue; +import 
java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; + +/** + * Runtime context for job master, which will be stored in the backend when a checkpoint is saved. + * + *

Including: graph, resource, checkpoint info, etc. + */ +public class JobMasterRuntimeContext implements Serializable { + + /*--------------Checkpoint----------------*/ + public volatile List checkpointIds = new ArrayList<>(); + public volatile long lastCheckpointId = 0; + public volatile long lastCpTimestamp = 0; + public volatile BlockingQueue cpCmds = new LinkedBlockingQueue<>(); + /*--------------Failover----------------*/ + public volatile BlockingQueue foCmds = new ArrayBlockingQueue<>(8192); + public volatile Set unfinishedFoCmds = Sets.newConcurrentHashSet(); + private StreamingConfig conf; + private JobGraph jobGraph; + private volatile ExecutionGraph executionGraph; + + public JobMasterRuntimeContext(StreamingConfig conf) { + this.conf = conf; + } + + public String getJobName() { + return conf.masterConfig.commonConfig.jobName(); + } + + public StreamingConfig getConf() { + return conf; + } + + public JobGraph getJobGraph() { + return jobGraph; + } + + public void setJobGraph(JobGraph jobGraph) { + this.jobGraph = jobGraph; + } + + public ExecutionGraph getExecutionGraph() { + return executionGraph; + } + + public void setExecutionGraph(ExecutionGraph executionGraph) { + this.executionGraph = executionGraph; + } + + public Long getLastValidCheckpointId() { + if (checkpointIds.isEmpty()) { + // OL is invalid checkpoint id, worker will pass it + return 0L; + } + return checkpointIds.get(checkpointIds.size() - 1); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("jobGraph", jobGraph) + .add("executionGraph", executionGraph) + .add("conf", conf.getMap()) + .toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java new file mode 100644 index 00000000..73323da4 --- /dev/null +++ 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java @@ -0,0 +1,45 @@ +package io.ray.streaming.runtime.master.coordinator; + +import io.ray.api.Ray; +import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.master.graphmanager.GraphManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class BaseCoordinator implements Runnable { + + private static final Logger LOG = LoggerFactory.getLogger(BaseCoordinator.class); + + protected final JobMaster jobMaster; + + protected final JobMasterRuntimeContext runtimeContext; + protected final GraphManager graphManager; + protected volatile boolean closed; + private Thread thread; + + public BaseCoordinator(JobMaster jobMaster) { + this.jobMaster = jobMaster; + this.runtimeContext = jobMaster.getRuntimeContext(); + this.graphManager = jobMaster.getGraphManager(); + } + + public void start() { + thread = + new Thread( + Ray.wrapRunnable(this), this.getClass().getName() + "-" + System.currentTimeMillis()); + thread.start(); + } + + public void stop() { + closed = true; + + try { + if (thread != null) { + thread.join(30000); + } + } catch (InterruptedException e) { + LOG.error("Coordinator thread exit has exception.", e); + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java new file mode 100644 index 00000000..05064c54 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java @@ -0,0 +1,222 @@ +package io.ray.streaming.runtime.master.coordinator; + +import com.google.common.base.Preconditions; +import io.ray.api.BaseActorHandle; +import io.ray.api.ObjectRef; 
+import io.ray.api.id.ActorId; +import io.ray.runtime.exception.RayException; +import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd; +import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport; +import io.ray.streaming.runtime.rpc.RemoteCallWorker; +import io.ray.streaming.runtime.worker.JobWorker; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * CheckpointCoordinator is the controller of checkpoint, responsible for triggering checkpoint, + * collecting {@link JobWorker}'s reports and calling {@link JobWorker} to clear expired checkpoints + * when new checkpoint finished. + */ +public class CheckpointCoordinator extends BaseCoordinator { + + private static final Logger LOG = LoggerFactory.getLogger(CheckpointCoordinator.class); + private final Set pendingCheckpointActors = new HashSet<>(); + private final Set interruptedCheckpointSet = new HashSet<>(); + private final int cpIntervalSecs; + private final int cpTimeoutSecs; + + public CheckpointCoordinator(JobMaster jobMaster) { + super(jobMaster); + + // get checkpoint interval from conf + this.cpIntervalSecs = runtimeContext.getConf().masterConfig.checkpointConfig.cpIntervalSecs(); + this.cpTimeoutSecs = runtimeContext.getConf().masterConfig.checkpointConfig.cpTimeoutSecs(); + + // Trigger next checkpoint in interval by reset last checkpoint timestamp. 
+ runtimeContext.lastCpTimestamp = System.currentTimeMillis(); + } + + @Override + public void run() { + while (!closed) { + try { + final BaseWorkerCmd command = runtimeContext.cpCmds.poll(1, TimeUnit.SECONDS); + if (command != null) { + if (command instanceof WorkerCommitReport) { + processCommitReport((WorkerCommitReport) command); + } else { + interruptCheckpoint(); + } + } + + if (!pendingCheckpointActors.isEmpty()) { + // if wait commit report timeout, this cp fail, and restart next cp + if (timeoutOnWaitCheckpoint()) { + LOG.warn( + "Waiting for checkpoint {} timeout, pending cp actors is {}.", + runtimeContext.lastCheckpointId, + graphManager.getExecutionGraph().getActorName(pendingCheckpointActors)); + + interruptCheckpoint(); + } + } else { + maybeTriggerCheckpoint(); + } + } catch (Throwable e) { + LOG.error("Checkpoint coordinator occur err.", e); + try { + interruptCheckpoint(); + } catch (Throwable interruptE) { + LOG.error("Ignore interrupt checkpoint exception in catch block."); + } + } + } + LOG.warn("Checkpoint coordinator thread exit."); + } + + public Boolean reportJobWorkerCommit(WorkerCommitReport report) { + LOG.info("Report job worker commit {}.", report); + + Boolean ret = runtimeContext.cpCmds.offer(report); + if (!ret) { + LOG.warn("Report job worker commit failed, because command queue is full."); + } + return ret; + } + + private void processCommitReport(WorkerCommitReport commitReport) { + LOG.info( + "Start process commit report {}, from actor name={}.", + commitReport, + graphManager.getExecutionGraph().getActorName(commitReport.fromActorId)); + + try { + Preconditions.checkArgument( + commitReport.commitCheckpointId == runtimeContext.lastCheckpointId, + "expect checkpointId %s, but got %s", + runtimeContext.lastCheckpointId, + commitReport); + + if (!pendingCheckpointActors.contains(commitReport.fromActorId)) { + LOG.warn("Invalid commit report, skipped."); + return; + } + + pendingCheckpointActors.remove(commitReport.fromActorId); 
+ LOG.info( + "Pending actors after this commit: {}.", + graphManager.getExecutionGraph().getActorName(pendingCheckpointActors)); + + // checkpoint finish + if (pendingCheckpointActors.isEmpty()) { + // actor finish + runtimeContext.checkpointIds.add(runtimeContext.lastCheckpointId); + + if (clearExpiredCpStateAndQueueMsg()) { + // save master context + jobMaster.saveContext(); + + LOG.info("Finish checkpoint: {}.", runtimeContext.lastCheckpointId); + } else { + LOG.warn("Fail to do checkpoint: {}.", runtimeContext.lastCheckpointId); + } + } + + LOG.info("Process commit report {} success.", commitReport); + } catch (Throwable e) { + LOG.warn("Process commit report has exception.", e); + } + } + + private void triggerCheckpoint() { + interruptedCheckpointSet.clear(); + if (LOG.isInfoEnabled()) { + LOG.info("Start trigger checkpoint {}.", runtimeContext.lastCheckpointId + 1); + } + + List allIds = graphManager.getExecutionGraph().getAllActorsId(); + // do the checkpoint + pendingCheckpointActors.addAll(allIds); + + // inc last checkpoint id + ++runtimeContext.lastCheckpointId; + + final List sourcesRet = new ArrayList<>(); + + graphManager + .getExecutionGraph() + .getSourceActors() + .forEach( + actor -> { + sourcesRet.add( + RemoteCallWorker.triggerCheckpoint(actor, runtimeContext.lastCheckpointId)); + }); + + for (ObjectRef rayObject : sourcesRet) { + if (rayObject.get() instanceof RayException) { + LOG.warn("Trigger checkpoint has exception.", (RayException) rayObject.get()); + throw (RayException) rayObject.get(); + } + } + runtimeContext.lastCpTimestamp = System.currentTimeMillis(); + LOG.info("Trigger checkpoint success."); + } + + private void interruptCheckpoint() { + // notify checkpoint timeout is time-consuming while many workers crash or + // container failover. 
+ if (interruptedCheckpointSet.contains(runtimeContext.lastCheckpointId)) { + LOG.warn("Skip interrupt duplicated checkpoint id : {}.", runtimeContext.lastCheckpointId); + return; + } + interruptedCheckpointSet.add(runtimeContext.lastCheckpointId); + LOG.warn("Interrupt checkpoint, checkpoint id : {}.", runtimeContext.lastCheckpointId); + + List allActor = graphManager.getExecutionGraph().getAllActors(); + if (runtimeContext.lastCheckpointId > runtimeContext.getLastValidCheckpointId()) { + RemoteCallWorker.notifyCheckpointTimeoutParallel(allActor, runtimeContext.lastCheckpointId); + } + + if (!pendingCheckpointActors.isEmpty()) { + pendingCheckpointActors.clear(); + } + maybeTriggerCheckpoint(); + } + + private void maybeTriggerCheckpoint() { + if (readyToTrigger()) { + triggerCheckpoint(); + } + } + + private boolean clearExpiredCpStateAndQueueMsg() { + // queue msg must clear when first checkpoint finish + List allActor = graphManager.getExecutionGraph().getAllActors(); + if (1 == runtimeContext.checkpointIds.size()) { + Long msgExpiredCheckpointId = runtimeContext.checkpointIds.get(0); + RemoteCallWorker.clearExpiredCheckpointParallel(allActor, 0L, msgExpiredCheckpointId); + } + + if (runtimeContext.checkpointIds.size() > 1) { + Long stateExpiredCpId = runtimeContext.checkpointIds.remove(0); + Long msgExpiredCheckpointId = runtimeContext.checkpointIds.get(0); + RemoteCallWorker.clearExpiredCheckpointParallel( + allActor, stateExpiredCpId, msgExpiredCheckpointId); + } + return true; + } + + private boolean readyToTrigger() { + return (System.currentTimeMillis() - runtimeContext.lastCpTimestamp) >= cpIntervalSecs * 1000; + } + + private boolean timeoutOnWaitCheckpoint() { + return (System.currentTimeMillis() - runtimeContext.lastCpTimestamp) >= cpTimeoutSecs * 1000; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java new file mode 100644 index 00000000..5a3f235d --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java @@ -0,0 +1,311 @@ +package io.ray.streaming.runtime.master.coordinator; + +import io.ray.api.BaseActorHandle; +import io.ray.api.id.ActorId; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd; +import io.ray.streaming.runtime.master.coordinator.command.InterruptCheckpointRequest; +import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest; +import io.ray.streaming.runtime.rpc.async.AsyncRemoteCaller; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; +import io.ray.streaming.runtime.util.ResourceUtil; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.apache.commons.collections.map.DefaultedMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class FailoverCoordinator extends BaseCoordinator { + + private static final Logger LOG = LoggerFactory.getLogger(FailoverCoordinator.class); + + private static final int ROLLBACK_RETRY_TIME_MS = 10 * 1000; + private final Object cmdLock = new Object(); + private final AsyncRemoteCaller asyncRemoteCaller; + private long currentCascadingGroupId = 0; + private final Map isRollbacking = + DefaultedMap.decorate(new ConcurrentHashMap(), false); + + public FailoverCoordinator(JobMaster jobMaster, boolean isRecover) { + 
this(jobMaster, new AsyncRemoteCaller(), isRecover); + } + + public FailoverCoordinator( + JobMaster jobMaster, AsyncRemoteCaller asyncRemoteCaller, boolean isRecover) { + super(jobMaster); + + this.asyncRemoteCaller = asyncRemoteCaller; + // recover unfinished FO commands + JobMasterRuntimeContext runtimeContext = jobMaster.getRuntimeContext(); + if (isRecover) { + runtimeContext.foCmds.addAll(runtimeContext.unfinishedFoCmds); + } + runtimeContext.unfinishedFoCmds.clear(); + } + + @Override + public void run() { + while (!closed) { + try { + final BaseWorkerCmd command; + // see rollback() for lock reason + synchronized (cmdLock) { + command = jobMaster.getRuntimeContext().foCmds.poll(1, TimeUnit.SECONDS); + } + if (null == command) { + continue; + } + if (command instanceof WorkerRollbackRequest) { + jobMaster.getRuntimeContext().unfinishedFoCmds.add(command); + dealWithRollbackRequest((WorkerRollbackRequest) command); + } + } catch (Throwable e) { + LOG.error("Fo coordinator occur err.", e); + } + } + LOG.warn("Fo coordinator thread exit."); + } + + private Boolean isDuplicateRequest(WorkerRollbackRequest request) { + try { + Object[] foCmdsArray = runtimeContext.foCmds.toArray(); + for (Object cmd : foCmdsArray) { + if (request.fromActorId.equals(((BaseWorkerCmd) cmd).fromActorId)) { + return true; + } + } + } catch (Exception e) { + LOG.warn("Check request is duplicated failed.", e); + } + return false; + } + + public Boolean requestJobWorkerRollback(WorkerRollbackRequest request) { + LOG.info("Request job worker rollback {}.", request); + boolean ret; + if (!isDuplicateRequest(request)) { + ret = runtimeContext.foCmds.offer(request); + } else { + LOG.warn("Skip duplicated worker rollback request, {}.", request.toString()); + return true; + } + jobMaster.saveContext(); + if (!ret) { + LOG.warn("Request job worker rollback failed, because command queue is full."); + } + return ret; + } + + private void dealWithRollbackRequest(WorkerRollbackRequest 
rollbackRequest) { + LOG.info("Start deal with rollback request {}.", rollbackRequest); + + ExecutionVertex exeVertex = getExeVertexFromRequest(rollbackRequest); + + // Reset pid for new-rollback actor. + if (null != rollbackRequest.getPid() + && !rollbackRequest.getPid().equals(WorkerRollbackRequest.DEFAULT_PID)) { + exeVertex.setPid(rollbackRequest.getPid()); + } + + if (isRollbacking.get(exeVertex)) { + LOG.info("Vertex {} is rollbacking, skip rollback again.", exeVertex); + return; + } + + String hostname = ""; + Optional container = + ResourceUtil.getContainerById( + jobMaster.getResourceManager().getRegisteredContainers(), exeVertex.getContainerId()); + if (container.isPresent()) { + hostname = container.get().getHostname(); + } + + if (rollbackRequest.isForcedRollback) { + interruptCheckpointAndRollback(rollbackRequest); + } else { + asyncRemoteCaller.checkIfNeedRollbackAsync( + exeVertex.getWorkerActor(), + res -> { + if (!res) { + LOG.info("Vertex {} doesn't need to rollback, skip it.", exeVertex); + return; + } + interruptCheckpointAndRollback(rollbackRequest); + }, + throwable -> { + LOG.error( + "Exception when calling checkIfNeedRollbackAsync, maybe vertex is dead" + + ", ignore this request, vertex={}.", + exeVertex, + throwable); + }); + } + + LOG.info("Deal with rollback request {} success.", rollbackRequest); + } + + private void interruptCheckpointAndRollback(WorkerRollbackRequest rollbackRequest) { + // assign a cascadingGroupId + if (rollbackRequest.cascadingGroupId == null) { + rollbackRequest.cascadingGroupId = currentCascadingGroupId++; + } + // get last valid checkpoint id then call worker rollback + rollback( + jobMaster.getRuntimeContext().getLastValidCheckpointId(), + rollbackRequest, + currentCascadingGroupId); + // we interrupt current checkpoint for 2 considerations: + // 1. current checkpoint might be timeout, because barrier might be lost after failover. so we + // interrupt current checkpoint to avoid waiting. + // 2. 
when we want to rollback vertex to n, job finished checkpoint n+1 and cleared state + // of checkpoint n. + jobMaster.getRuntimeContext().cpCmds.offer(new InterruptCheckpointRequest()); + } + + /** + * call worker rollback, and deal with it's reports. callback won't be finished until the entire + * DAG back to normal. + * + * @param checkpointId checkpointId to be rollback + * @param rollbackRequest worker rollback request + * @param cascadingGroupId all rollback of a cascading group should have same ID + */ + private void rollback( + long checkpointId, WorkerRollbackRequest rollbackRequest, long cascadingGroupId) { + ExecutionVertex exeVertex = getExeVertexFromRequest(rollbackRequest); + LOG.info( + "Call vertex {} to rollback, checkpoint id is {}, cascadingGroupId={}.", + exeVertex, + checkpointId, + cascadingGroupId); + + isRollbacking.put(exeVertex, true); + + asyncRemoteCaller.rollback( + exeVertex.getWorkerActor(), + checkpointId, + result -> { + List newRollbackRequests = new ArrayList<>(); + switch (result.getResultEnum()) { + case SUCCESS: + ChannelRecoverInfo recoverInfo = result.getResultObj(); + LOG.info( + "Vertex {} rollback done, dataLostQueues={}, msg={}, cascadingGroupId={}.", + exeVertex, + recoverInfo.getDataLostQueues(), + result.getResultMsg(), + cascadingGroupId); + // rollback upstream if vertex reports abnormal input queues + newRollbackRequests = + cascadeUpstreamActors( + recoverInfo.getDataLostQueues(), exeVertex, cascadingGroupId); + break; + case SKIPPED: + LOG.info( + "Vertex skip rollback, result = {}, cascadingGroupId={}.", + result, + cascadingGroupId); + break; + default: + LOG.error( + "Rollback vertex {} failed, result={}, cascadingGroupId={}," + + " rollback this worker again after {} ms.", + exeVertex, + result, + cascadingGroupId, + ROLLBACK_RETRY_TIME_MS); + Thread.sleep(ROLLBACK_RETRY_TIME_MS); + LOG.info( + "Add rollback request for {} again, cascadingGroupId={}.", + exeVertex, + cascadingGroupId); + 
newRollbackRequests.add( + new WorkerRollbackRequest(exeVertex, "", "Rollback failed, try again.", false)); + break; + } + + // lock to avoid executing new rollback requests added. + // consider such a case: A->B->C, C cascade B, and B cascade A + // if B is rollback before B's rollback request is saved, and then JobMaster crashed, + // then A will never be rollback. + synchronized (cmdLock) { + jobMaster.getRuntimeContext().foCmds.addAll(newRollbackRequests); + // this rollback request is finished, remove it. + jobMaster.getRuntimeContext().unfinishedFoCmds.remove(rollbackRequest); + jobMaster.saveContext(); + } + isRollbacking.put(exeVertex, false); + }, + throwable -> { + LOG.error("Exception when calling vertex to rollback, vertex={}.", exeVertex, throwable); + isRollbacking.put(exeVertex, false); + }); + + LOG.info("Finish rollback vertex {}, checkpoint id is {}.", exeVertex, checkpointId); + } + + private List cascadeUpstreamActors( + Set dataLostQueues, ExecutionVertex fromVertex, long cascadingGroupId) { + List cascadedRollbackRequest = new ArrayList<>(); + // rollback upstream if vertex reports abnormal input queues + dataLostQueues.forEach( + q -> { + BaseActorHandle upstreamActor = + graphManager.getExecutionGraph().getPeerActor(fromVertex.getWorkerActor(), q); + ExecutionVertex upstreamExeVertex = getExecutionVertex(upstreamActor); + // vertexes that has already cascaded by other vertex in the same level + // of graph should be ignored. 
+ if (isRollbacking.get(upstreamExeVertex)) { + return; + } + LOG.info( + "Call upstream vertex {} of vertex {} to rollback, cascadingGroupId={}.", + upstreamExeVertex, + fromVertex, + cascadingGroupId); + String hostname = ""; + Optional container = + ResourceUtil.getContainerById( + jobMaster.getResourceManager().getRegisteredContainers(), + upstreamExeVertex.getContainerId()); + if (container.isPresent()) { + hostname = container.get().getHostname(); + } + // force upstream vertexes to rollback + WorkerRollbackRequest upstreamRequest = + new WorkerRollbackRequest( + upstreamExeVertex, + hostname, + String.format("Cascading rollback from %s", fromVertex), + true); + upstreamRequest.cascadingGroupId = cascadingGroupId; + cascadedRollbackRequest.add(upstreamRequest); + }); + return cascadedRollbackRequest; + } + + private ExecutionVertex getExeVertexFromRequest(WorkerRollbackRequest rollbackRequest) { + ActorId actorId = rollbackRequest.fromActorId; + Optional rayActor = graphManager.getExecutionGraph().getActorById(actorId); + if (!rayActor.isPresent()) { + throw new RuntimeException("Can not find ray actor of ID " + actorId); + } + return getExecutionVertex(rollbackRequest.fromActorId); + } + + private ExecutionVertex getExecutionVertex(BaseActorHandle actor) { + return graphManager.getExecutionGraph().getExecutionVertexByActorId(actor.getId()); + } + + private ExecutionVertex getExecutionVertex(ActorId actorId) { + return graphManager.getExecutionGraph().getExecutionVertexByActorId(actorId); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java new file mode 100644 index 00000000..b222b3b5 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java @@ -0,0 +1,15 @@ +package 
io.ray.streaming.runtime.master.coordinator.command; + +import io.ray.api.id.ActorId; +import java.io.Serializable; + +public abstract class BaseWorkerCmd implements Serializable { + + public ActorId fromActorId; + + public BaseWorkerCmd() {} + + protected BaseWorkerCmd(ActorId actorId) { + this.fromActorId = actorId; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java new file mode 100644 index 00000000..28843296 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java @@ -0,0 +1,3 @@ +package io.ray.streaming.runtime.master.coordinator.command; + +public final class InterruptCheckpointRequest extends BaseWorkerCmd {} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java new file mode 100644 index 00000000..7750ce1b --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java @@ -0,0 +1,22 @@ +package io.ray.streaming.runtime.master.coordinator.command; + +import com.google.common.base.MoreObjects; +import io.ray.api.id.ActorId; + +public final class WorkerCommitReport extends BaseWorkerCmd { + + public final long commitCheckpointId; + + public WorkerCommitReport(ActorId actorId, long commitCheckpointId) { + super(actorId); + this.commitCheckpointId = commitCheckpointId; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("commitCheckpointId", commitCheckpointId) + .add("fromActorId", fromActorId) + .toString(); + } 
+} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java new file mode 100644 index 00000000..df759630 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java @@ -0,0 +1,58 @@ +package io.ray.streaming.runtime.master.coordinator.command; + +import com.google.common.base.MoreObjects; +import io.ray.api.id.ActorId; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; + +public final class WorkerRollbackRequest extends BaseWorkerCmd { + + public static String DEFAULT_PID = "UNKNOWN_PID"; + public Long cascadingGroupId = null; + public boolean isForcedRollback = false; + private String exceptionMsg = "No detail message."; + private String hostname = "UNKNOWN_HOST"; + private String pid = DEFAULT_PID; + + public WorkerRollbackRequest(ActorId actorId) { + super(actorId); + } + + public WorkerRollbackRequest(ActorId actorId, String msg) { + super(actorId); + exceptionMsg = msg; + } + + public WorkerRollbackRequest( + ExecutionVertex executionVertex, String hostname, String msg, boolean isForcedRollback) { + + super(executionVertex.getWorkerActorId()); + + this.hostname = hostname; + this.pid = executionVertex.getPid(); + this.exceptionMsg = msg; + this.isForcedRollback = isForcedRollback; + } + + public WorkerRollbackRequest(ActorId actorId, String msg, String hostname, String pid) { + this(actorId, msg); + this.hostname = hostname; + this.pid = pid; + } + + public String getRollbackExceptionMsg() { + return exceptionMsg; + } + + public String getHostname() { + return hostname; + } + + public String getPid() { + return pid; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("fromActorId", 
fromActorId).toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManager.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManager.java new file mode 100644 index 00000000..b563917d --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManager.java @@ -0,0 +1,40 @@ +package io.ray.streaming.runtime.master.graphmanager; + +import io.ray.streaming.jobgraph.JobGraph; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; + +/** + * Graph manager is one of the important roles of JobMaster. It mainly focuses on graph management. + * + *

<p>Such as:
 + *
 + * <ol>
 + *   <li>1. Build execution graph from job graph.
 + *   <li>2. Do modifications or operations on graph.
 + *   <li>3. Query vertex info from graph.
 + * </ol>
+ */ +public interface GraphManager { + + /** + * Build execution graph from job graph. + * + * @param jobGraph logical plan of streaming job. + * @return physical plan of streaming job. + */ + ExecutionGraph buildExecutionGraph(JobGraph jobGraph); + + /** + * Get job graph. + * + * @return the job graph. + */ + JobGraph getJobGraph(); + + /** + * Get execution graph. + * + * @return the execution graph. + */ + ExecutionGraph getExecutionGraph(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java new file mode 100644 index 00000000..b33b95a3 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java @@ -0,0 +1,130 @@ +package io.ray.streaming.runtime.master.graphmanager; + +import io.ray.api.BaseActorHandle; +import io.ray.streaming.jobgraph.JobGraph; +import io.ray.streaming.jobgraph.JobVertex; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobEdge; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class GraphManagerImpl implements GraphManager { + + private static final Logger LOG = LoggerFactory.getLogger(GraphManagerImpl.class); + + protected final JobMasterRuntimeContext runtimeContext; + + public GraphManagerImpl(JobMasterRuntimeContext runtimeContext) { 
+ this.runtimeContext = runtimeContext; + } + + @Override + public ExecutionGraph buildExecutionGraph(JobGraph jobGraph) { + LOG.info("Begin build execution graph with job graph {}.", jobGraph); + + // setup structure + ExecutionGraph executionGraph = setupStructure(jobGraph); + + // set max parallelism + int maxParallelism = + jobGraph.getJobVertices().stream() + .map(JobVertex::getParallelism) + .max(Integer::compareTo) + .get(); + executionGraph.setMaxParallelism(maxParallelism); + + // set job config + executionGraph.setJobConfig(jobGraph.getJobConfig()); + + LOG.info("Build execution graph success."); + return executionGraph; + } + + private ExecutionGraph setupStructure(JobGraph jobGraph) { + ExecutionGraph executionGraph = new ExecutionGraph(jobGraph.getJobName()); + Map jobConfig = jobGraph.getJobConfig(); + + // create vertex + Map exeJobVertexMap = new LinkedHashMap<>(); + Map executionVertexMap = new HashMap<>(); + long buildTime = executionGraph.getBuildTime(); + for (JobVertex jobVertex : jobGraph.getJobVertices()) { + int jobVertexId = jobVertex.getVertexId(); + exeJobVertexMap.put( + jobVertexId, + new ExecutionJobVertex( + jobVertex, jobConfig, executionGraph.getExecutionVertexIdGenerator(), buildTime)); + } + + // for each job edge, connect all source exeVertices and target exeVertices + jobGraph + .getJobEdges() + .forEach( + jobEdge -> { + ExecutionJobVertex source = exeJobVertexMap.get(jobEdge.getSrcVertexId()); + ExecutionJobVertex target = exeJobVertexMap.get(jobEdge.getTargetVertexId()); + + ExecutionJobEdge executionJobEdge = new ExecutionJobEdge(source, target, jobEdge); + + source.getOutputEdges().add(executionJobEdge); + target.getInputEdges().add(executionJobEdge); + + source + .getExecutionVertices() + .forEach( + sourceExeVertex -> { + target + .getExecutionVertices() + .forEach( + targetExeVertex -> { + // pre-process some mappings + executionVertexMap.put( + targetExeVertex.getExecutionVertexId(), targetExeVertex); + 
executionVertexMap.put( + sourceExeVertex.getExecutionVertexId(), sourceExeVertex); + // build execution edge + ExecutionEdge executionEdge = + new ExecutionEdge( + sourceExeVertex, targetExeVertex, executionJobEdge); + sourceExeVertex.getOutputEdges().add(executionEdge); + targetExeVertex.getInputEdges().add(executionEdge); + }); + }); + }); + + // set execution job vertex into execution graph + executionGraph.setExecutionJobVertexMap(exeJobVertexMap); + executionGraph.setExecutionVertexMap(executionVertexMap); + + return executionGraph; + } + + private void addActorToChannelGroupedActors( + Map> channelGroupedActors, + String channelId, + BaseActorHandle actor) { + + Set actorSet = + channelGroupedActors.computeIfAbsent(channelId, k -> new HashSet<>()); + actorSet.add(actor); + } + + @Override + public JobGraph getJobGraph() { + return runtimeContext.getJobGraph(); + } + + @Override + public ExecutionGraph getExecutionGraph() { + return runtimeContext.getExecutionGraph(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceAssignmentView.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceAssignmentView.java new file mode 100644 index 00000000..ea6fb098 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceAssignmentView.java @@ -0,0 +1,16 @@ +package io.ray.streaming.runtime.master.resourcemanager; + +import io.ray.streaming.runtime.core.resource.ContainerId; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Cluster resource allocation view, used to statically view cluster resource information. 
*/ +public class ResourceAssignmentView extends HashMap> { + + public static ResourceAssignmentView of(Map> assignmentView) { + ResourceAssignmentView view = new ResourceAssignmentView(); + view.putAll(assignmentView); + return view; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManager.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManager.java new file mode 100644 index 00000000..fbe3f696 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManager.java @@ -0,0 +1,16 @@ +package io.ray.streaming.runtime.master.resourcemanager; + +import com.google.common.collect.ImmutableList; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategy; + +/** ResourceManager(RM) is responsible for resource de-/allocation and monitoring ray cluster. */ +public interface ResourceManager extends ResourceAssignStrategy { + + /** + * Get registered containers, the container list is read-only. 
+ * + * @return the registered container list + */ + ImmutableList getRegisteredContainers(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java new file mode 100644 index 00000000..5c180d2b --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java @@ -0,0 +1,177 @@ +package io.ray.streaming.runtime.master.resourcemanager; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.ray.api.Ray; +import io.ray.api.id.UniqueId; +import io.ray.api.runtimecontext.NodeInfo; +import io.ray.streaming.runtime.config.StreamingMasterConfig; +import io.ray.streaming.runtime.config.master.ResourceConfig; +import io.ray.streaming.runtime.config.types.ResourceAssignStrategyType; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.core.resource.Resources; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategy; +import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategyFactory; +import io.ray.streaming.runtime.util.RayUtils; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ResourceManagerImpl implements ResourceManager { + + private static final Logger LOG = LoggerFactory.getLogger(ResourceManagerImpl.class); + + // Container used 
tag + private static final String CONTAINER_ENGAGED_KEY = "CONTAINER_ENGAGED_KEY"; + /** Resource description information. */ + private final Resources resources; + /** Timing resource updating thread */ + private final ScheduledExecutorService resourceUpdater = + new ScheduledThreadPoolExecutor( + 1, new ThreadFactoryBuilder().setNameFormat("resource-update-thread").build()); + /** Job runtime context. */ + private JobMasterRuntimeContext runtimeContext; + /** Resource related configuration. */ + private ResourceConfig resourceConfig; + /** Slot assign strategy. */ + private ResourceAssignStrategy resourceAssignStrategy; + /** Customized actor number for each container */ + private int actorNumPerContainer; + + public ResourceManagerImpl(JobMasterRuntimeContext runtimeContext) { + this.runtimeContext = runtimeContext; + StreamingMasterConfig masterConfig = runtimeContext.getConf().masterConfig; + + this.resourceConfig = masterConfig.resourceConfig; + this.resources = new Resources(); + LOG.info( + "ResourceManagerImpl begin init, conf is {}, resources are {}.", resourceConfig, resources); + + // Init custom resource configurations + this.actorNumPerContainer = resourceConfig.actorNumPerContainer(); + + ResourceAssignStrategyType resourceAssignStrategyType = + ResourceAssignStrategyType.PIPELINE_FIRST_STRATEGY; + this.resourceAssignStrategy = + ResourceAssignStrategyFactory.getStrategy(resourceAssignStrategyType); + LOG.info("Slot assign strategy: {}.", resourceAssignStrategy.getName()); + + // Init resource + initResource(); + + checkAndUpdateResourcePeriodically(); + + LOG.info("ResourceManagerImpl init success."); + } + + @Override + public ResourceAssignmentView assignResource( + List containers, ExecutionGraph executionGraph) { + return resourceAssignStrategy.assignResource(containers, executionGraph); + } + + @Override + public String getName() { + return resourceAssignStrategy.getName(); + } + + @Override + public ImmutableList getRegisteredContainers() { + 
LOG.info("Current resource detail: {}.", resources.toString()); + return resources.getRegisteredContainers(); + } + + /** + * Check the status of ray cluster node and update the internal resource information of streaming + * system. + */ + private void checkAndUpdateResource() { + // Get add&del nodes(node -> container) + Map latestNodeInfos = RayUtils.getAliveNodeInfoMap(); + + List addNodes = + latestNodeInfos.keySet().stream().filter(this::isAddedNode).collect(Collectors.toList()); + + List deleteNodes = + resources.getRegisteredContainerMap().keySet().stream() + .filter(nodeId -> !latestNodeInfos.containsKey(nodeId)) + .collect(Collectors.toList()); + LOG.info( + "Latest node infos: {}, current containers: {}, add nodes: {}, delete nodes: {}.", + latestNodeInfos, + resources.getRegisteredContainers(), + addNodes, + deleteNodes); + + if (!addNodes.isEmpty() || !deleteNodes.isEmpty()) { + LOG.info("Latest node infos from GCS: {}", latestNodeInfos); + LOG.info("Resource details: {}.", resources.toString()); + LOG.info("Get add nodes info: {}, del nodes info: {}.", addNodes, deleteNodes); + + // unregister containers + unregisterDeletedContainer(deleteNodes); + + // register containers + registerNewContainers( + addNodes.stream().map(latestNodeInfos::get).collect(Collectors.toList())); + } + } + + private void registerNewContainers(List nodeInfos) { + LOG.info("Start to register containers. 
new add node infos are: {}.", nodeInfos); + + if (nodeInfos == null || nodeInfos.isEmpty()) { + LOG.info("NodeInfos is null or empty, skip registry."); + return; + } + + for (NodeInfo nodeInfo : nodeInfos) { + registerContainer(nodeInfo); + } + } + + private void registerContainer(final NodeInfo nodeInfo) { + LOG.info("Register container {}.", nodeInfo); + + Container container = Container.from(nodeInfo); + + // failover case: container has already allocated actors + double availableCapacity = actorNumPerContainer - container.getAllocatedActorNum(); + + // update container's available dynamic resources + container.getAvailableResources().put(container.getName(), availableCapacity); + + // update register container list + resources.registerContainer(container); + } + + private void unregisterDeletedContainer(List deletedIds) { + LOG.info("Unregister container, deleted node ids are: {}.", deletedIds); + if (null == deletedIds || deletedIds.isEmpty()) { + return; + } + resources.unRegisterContainer(deletedIds); + } + + private void initResource() { + LOG.info("Init resource."); + checkAndUpdateResource(); + } + + private void checkAndUpdateResourcePeriodically() { + long intervalSecond = resourceConfig.resourceCheckIntervalSecond(); + this.resourceUpdater.scheduleAtFixedRate( + Ray.wrapRunnable(this::checkAndUpdateResource), 0, intervalSecond, TimeUnit.SECONDS); + } + + private boolean isAddedNode(UniqueId uniqueId) { + return !resources.getRegisteredContainerMap().containsKey(uniqueId); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ViewBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ViewBuilder.java new file mode 100644 index 00000000..76cf221f --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ViewBuilder.java @@ -0,0 +1,23 @@ +package 
io.ray.streaming.runtime.master.resourcemanager; + +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.core.resource.ContainerId; +import java.util.List; +import java.util.Map; + +/** ViewBuilder describes current cluster's resource allocation detail information */ +public class ViewBuilder { + + // Default constructor for serialization. + public ViewBuilder() {} + + public static ResourceAssignmentView buildResourceAssignmentView(List containers) { + Map> assignmentView = + containers.stream() + .collect( + java.util.stream.Collectors.toMap( + Container::getId, Container::getExecutionVertexIds)); + + return ResourceAssignmentView.of(assignmentView); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/strategy/ResourceAssignStrategy.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/strategy/ResourceAssignStrategy.java new file mode 100644 index 00000000..9ce131d2 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/strategy/ResourceAssignStrategy.java @@ -0,0 +1,23 @@ +package io.ray.streaming.runtime.master.resourcemanager.strategy; + +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.master.resourcemanager.ResourceAssignmentView; +import java.util.List; + +/** The ResourceAssignStrategy responsible assign {@link Container} to {@link ExecutionVertex}. 
*/ +public interface ResourceAssignStrategy { + + /** + * Assign {@link Container} for {@link ExecutionVertex} + * + * @param containers registered container + * @param executionGraph execution graph + * @return allocating view + */ + ResourceAssignmentView assignResource(List containers, ExecutionGraph executionGraph); + + /** Get container assign strategy name */ + String getName(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/strategy/ResourceAssignStrategyFactory.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/strategy/ResourceAssignStrategyFactory.java new file mode 100644 index 00000000..dddb1fed --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/strategy/ResourceAssignStrategyFactory.java @@ -0,0 +1,24 @@ +package io.ray.streaming.runtime.master.resourcemanager.strategy; + +import io.ray.streaming.runtime.config.types.ResourceAssignStrategyType; +import io.ray.streaming.runtime.master.resourcemanager.strategy.impl.PipelineFirstStrategy; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ResourceAssignStrategyFactory { + + private static final Logger LOG = LoggerFactory.getLogger(ResourceAssignStrategyFactory.class); + + public static ResourceAssignStrategy getStrategy(final ResourceAssignStrategyType type) { + ResourceAssignStrategy strategy = null; + LOG.info("Slot assign strategy is: {}.", type); + switch (type) { + case PIPELINE_FIRST_STRATEGY: + strategy = new PipelineFirstStrategy(); + break; + default: + throw new RuntimeException("strategy config error, no impl found for " + strategy); + } + return strategy; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/strategy/impl/PipelineFirstStrategy.java 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/strategy/impl/PipelineFirstStrategy.java new file mode 100644 index 00000000..48f2366c --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/strategy/impl/PipelineFirstStrategy.java @@ -0,0 +1,222 @@ +package io.ray.streaming.runtime.master.resourcemanager.strategy.impl; + +import io.ray.streaming.runtime.config.types.ResourceAssignStrategyType; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.core.resource.ResourceType; +import io.ray.streaming.runtime.master.resourcemanager.ResourceAssignmentView; +import io.ray.streaming.runtime.master.resourcemanager.ViewBuilder; +import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategy; +import io.ray.streaming.runtime.master.scheduler.ScheduleException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Based on Ray dynamic resource function, resource details(by ray gcs get) and execution logic + * diagram, PipelineFirstStrategy provides a actor scheduling strategies to make the cluster load + * balanced and controllable scheduling. Assume that we have 2 containers and have a DAG graph + * composed of a source node with parallelism of 2 and a sink node with parallelism of 2, the + * structure will be like: + * + *
+ *   container_0
+ *             |- source_1
+ *             |- sink_1
+ *   container_1
+ *             |- source_2
+ *             |- sink_2
+ * 
+ */ +public class PipelineFirstStrategy implements ResourceAssignStrategy { + + public static final Logger LOG = LoggerFactory.getLogger(PipelineFirstStrategy.class); + + private int currentContainerIndex = 0; + + /** + * Assign resource to each execution vertex in the given execution graph. + * + * @param containers registered containers + * @param executionGraph execution graph + * @return allocating map, key is container ID, value is list of vertextId, and contains vertices + */ + @Override + public ResourceAssignmentView assignResource( + List containers, ExecutionGraph executionGraph) { + + Map vertices = executionGraph.getExecutionJobVertexMap(); + Map vertexRemainingNum = new HashMap<>(); + + vertices.forEach( + (k, v) -> { + int size = v.getExecutionVertices().size(); + vertexRemainingNum.put(k, size); + }); + int totalExecutionVerticesNum = + vertexRemainingNum.values().stream().mapToInt(Integer::intValue).sum(); + int containerNum = containers.size(); + int capacityPerContainer = Math.max(totalExecutionVerticesNum / containerNum, 1); + + updateContainerCapacity(containers, capacityPerContainer); + + int enlargeCapacityThreshold = 0; + boolean enlarged = false; + if (capacityPerContainer * containerNum < totalExecutionVerticesNum) { + enlargeCapacityThreshold = capacityPerContainer * containerNum; + LOG.info("Need to enlarge capacity per container, threshold: {}.", enlargeCapacityThreshold); + } + LOG.info( + "Total execution vertices num: {}, container num: {}, capacity per container: {}.", + totalExecutionVerticesNum, + containerNum, + capacityPerContainer); + + int maxParallelism = executionGraph.getMaxParallelism(); + + int allocatedVertexCount = 0; + for (int i = 0; i < maxParallelism; i++) { + for (ExecutionJobVertex jobVertex : vertices.values()) { + List exeVertices = jobVertex.getExecutionVertices(); + // current job vertex assign finished + if (exeVertices.size() <= i) { + continue; + } + ExecutionVertex executionVertex = exeVertices.get(i); + 
Map requiredResource = executionVertex.getResource(); + if (requiredResource.containsKey(ResourceType.CPU.getValue())) { + LOG.info( + "Required resource contain {} value : {}, no limitation by default.", + ResourceType.CPU, + requiredResource.get(ResourceType.CPU.getValue())); + requiredResource.remove(ResourceType.CPU.getValue()); + } + + Container targetContainer = findMatchedContainer(requiredResource, containers); + + targetContainer.allocateActor(executionVertex); + allocatedVertexCount++; + // Once allocatedVertexCount reaches threshold, we should enlarge capacity + if (!enlarged + && enlargeCapacityThreshold > 0 + && allocatedVertexCount >= enlargeCapacityThreshold) { + updateContainerCapacity(containers, capacityPerContainer + 1); + enlarged = true; + LOG.info("Enlarge capacity per container to: {}.", containers.get(0).getCapacity()); + } + } + } + + ResourceAssignmentView allocatingView = ViewBuilder.buildResourceAssignmentView(containers); + LOG.info("Assigning resource finished, allocating map: {}.", allocatingView); + return allocatingView; + } + + @Override + public String getName() { + return ResourceAssignStrategyType.PIPELINE_FIRST_STRATEGY.getName(); + } + + /** + * Update container capacity. eg: we have 89 actors and 8 containers, capacity will be 11 when + * initialing, and will be increased to 12 when allocating actor#89, just for load balancing. 
+ */ + private void updateContainerCapacity(List containers, int capacity) { + containers.forEach(c -> c.updateCapacity(capacity)); + } + + /** + * Find a container which matches required resource + * + * @param requiredResource required resource + * @param containers registered containers + * @return container that matches the required resource + */ + private Container findMatchedContainer( + Map requiredResource, List containers) { + + LOG.info("Check resource, required: {}.", requiredResource); + + int checkedNum = 0; + // if current container does not have enough resource, go to the next one (loop) + while (!hasEnoughResource(requiredResource, getCurrentContainer(containers))) { + checkedNum++; + forwardToNextContainer(containers); + if (checkedNum >= containers.size()) { + throw new ScheduleException( + String.format( + "No enough resource left, required resource: %s, available resource: %s.", + requiredResource, containers)); + } + } + return getCurrentContainer(containers); + } + + /** + * Check if current container has enough resource + * + * @param requiredResource required resource + * @param container container + * @return true if matches, false else + */ + private boolean hasEnoughResource(Map requiredResource, Container container) { + LOG.info("Check resource for index: {}, container: {}", currentContainerIndex, container); + + if (null == requiredResource) { + return true; + } + + if (container.isFull()) { + LOG.info("Container {} is full.", container); + return false; + } + + Map availableResource = container.getAvailableResources(); + for (Map.Entry entry : requiredResource.entrySet()) { + if (availableResource.containsKey(entry.getKey())) { + if (availableResource.get(entry.getKey()) < entry.getValue()) { + LOG.warn( + "No enough resource for container {}. required: {}, available: {}.", + container.getAddress(), + requiredResource, + availableResource); + return false; + } + } else { + LOG.warn( + "No enough resource for container {}. 
required: {}, available: {}.", + container.getAddress(), + requiredResource, + availableResource); + return false; + } + } + + return true; + } + + /** + * Forward to next container + * + * @param containers registered container list + * @return next container in the list + */ + private Container forwardToNextContainer(List containers) { + this.currentContainerIndex = (this.currentContainerIndex + 1) % containers.size(); + return getCurrentContainer(containers); + } + + /** + * Get current container + * + * @param containers registered container + * @return current container to allocate actor + */ + private Container getCurrentContainer(List containers) { + return containers.get(currentContainerIndex); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobScheduler.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobScheduler.java new file mode 100644 index 00000000..d0fb60d5 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobScheduler.java @@ -0,0 +1,15 @@ +package io.ray.streaming.runtime.master.scheduler; + +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; + +/** Job scheduler is used to do the scheduling in JobMaster. */ +public interface JobScheduler { + + /** + * Schedule streaming job using the physical plan. 
+ * + * @param executionGraph physical plan + * @return scheduling result + */ + boolean scheduleJob(ExecutionGraph executionGraph); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java new file mode 100644 index 00000000..039715cc --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java @@ -0,0 +1,199 @@ +package io.ray.streaming.runtime.master.scheduler; + +import com.google.common.base.Preconditions; +import io.ray.api.ActorHandle; +import io.ray.streaming.runtime.config.StreamingConfig; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.master.graphmanager.GraphManager; +import io.ray.streaming.runtime.master.resourcemanager.ResourceManager; +import io.ray.streaming.runtime.master.resourcemanager.ViewBuilder; +import io.ray.streaming.runtime.master.scheduler.controller.WorkerLifecycleController; +import io.ray.streaming.runtime.worker.context.JobWorkerContext; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Job scheduler implementation. 
*/ +public class JobSchedulerImpl implements JobScheduler { + + private static final Logger LOG = LoggerFactory.getLogger(JobSchedulerImpl.class); + private final JobMaster jobMaster; + private final ResourceManager resourceManager; + private final GraphManager graphManager; + private final WorkerLifecycleController workerLifecycleController; + private StreamingConfig jobConfig; + + public JobSchedulerImpl(JobMaster jobMaster) { + this.jobMaster = jobMaster; + this.graphManager = jobMaster.getGraphManager(); + this.resourceManager = jobMaster.getResourceManager(); + this.workerLifecycleController = new WorkerLifecycleController(); + this.jobConfig = jobMaster.getRuntimeContext().getConf(); + + LOG.info("Scheduler initiated."); + } + + @Override + public boolean scheduleJob(ExecutionGraph executionGraph) { + LOG.info("Begin scheduling. Job: {}.", executionGraph.getJobName()); + + // Allocate resource then create workers + // Actor creation is in this step + prepareResourceAndCreateWorker(executionGraph); + + // now actor info is available in execution graph + // preprocess some handy mappings in execution graph + executionGraph.generateActorMappings(); + + // init worker context and start to run + initAndStart(executionGraph); + + return true; + } + + /** + * Allocate job workers' resource then create job workers' actor. + * + * @param executionGraph the physical plan + */ + protected void prepareResourceAndCreateWorker(ExecutionGraph executionGraph) { + List containers = resourceManager.getRegisteredContainers(); + + // Assign resource for execution vertices + resourceManager.assignResource(containers, executionGraph); + + LOG.info("Allocating map is: {}.", ViewBuilder.buildResourceAssignmentView(containers)); + + // Start all new added workers + createWorkers(executionGraph); + } + + /** + * Init JobMaster and JobWorkers then start JobWorkers. 
+ * + * @param executionGraph physical plan + */ + private void initAndStart(ExecutionGraph executionGraph) { + // generate vertex - context map + Map vertexToContextMap = buildWorkersContext(executionGraph); + + // init workers + Preconditions.checkState(initWorkers(vertexToContextMap)); + + // init master + initMaster(); + + // start workers + startWorkers(executionGraph, jobMaster.getRuntimeContext().lastCheckpointId); + } + + /** + * Create JobWorker actors according to the physical plan. + * + * @param executionGraph physical plan + * @return actor creation result + */ + public boolean createWorkers(ExecutionGraph executionGraph) { + LOG.info("Begin creating workers."); + long startTs = System.currentTimeMillis(); + + // create JobWorker actors + boolean createResult = + workerLifecycleController.createWorkers(executionGraph.getAllAddedExecutionVertices()); + + if (createResult) { + LOG.info("Finished creating workers. Cost {} ms.", System.currentTimeMillis() - startTs); + return true; + } else { + LOG.error("Failed to create workers. Cost {} ms.", System.currentTimeMillis() - startTs); + return false; + } + } + + /** + * Init JobWorkers according to the vertex and context infos. + * + * @param vertexToContextMap vertex - context map + */ + protected boolean initWorkers(Map vertexToContextMap) { + boolean succeed; + int timeoutMs = jobConfig.masterConfig.schedulerConfig.workerInitiationWaitTimeoutMs(); + succeed = workerLifecycleController.initWorkers(vertexToContextMap, timeoutMs); + if (!succeed) { + LOG.error("Failed to initiate workers in {} milliseconds", timeoutMs); + } + return succeed; + } + + /** Start JobWorkers according to the physical plan. 
*/ + public boolean startWorkers(ExecutionGraph executionGraph, long checkpointId) { + boolean result; + try { + result = + workerLifecycleController.startWorkers( + executionGraph, + checkpointId, + jobConfig.masterConfig.schedulerConfig.workerStartingWaitTimeoutMs()); + } catch (Exception e) { + LOG.error("Failed to start workers.", e); + return false; + } + return result; + } + + /** + * Build workers context. + * + * @param executionGraph execution graph + * @return vertex to worker context map + */ + protected Map buildWorkersContext( + ExecutionGraph executionGraph) { + ActorHandle masterActor = jobMaster.getJobMasterActor(); + + // build workers' context + Map vertexToContextMap = new HashMap<>(); + executionGraph + .getAllExecutionVertices() + .forEach( + vertex -> { + JobWorkerContext context = buildJobWorkerContext(vertex, masterActor); + vertexToContextMap.put(vertex, context); + }); + return vertexToContextMap; + } + + private JobWorkerContext buildJobWorkerContext( + ExecutionVertex executionVertex, ActorHandle masterActor) { + + // create java worker context + JobWorkerContext context = new JobWorkerContext(masterActor, executionVertex); + + return context; + } + + /** + * Destroy JobWorkers according to the vertex infos. 
+ * + * @param executionVertices specified vertices + */ + public boolean destroyWorkers(List executionVertices) { + boolean result; + try { + result = workerLifecycleController.destroyWorkers(executionVertices); + } catch (Exception e) { + LOG.error("Failed to destroy workers.", e); + return false; + } + return result; + } + + private void initMaster() { + jobMaster.init(false); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/ScheduleException.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/ScheduleException.java new file mode 100644 index 00000000..9841e036 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/ScheduleException.java @@ -0,0 +1,25 @@ +package io.ray.streaming.runtime.master.scheduler; + +public class ScheduleException extends RuntimeException { + + public ScheduleException() { + super(); + } + + public ScheduleException(String message) { + super(message); + } + + public ScheduleException(String message, Throwable cause) { + super(message, cause); + } + + public ScheduleException(Throwable cause) { + super(cause); + } + + protected ScheduleException( + String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java new file mode 100644 index 00000000..3cd3984b --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java @@ -0,0 +1,213 @@ +package io.ray.streaming.runtime.master.scheduler.controller; + +import 
io.ray.api.BaseActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.api.WaitResult; +import io.ray.api.function.PyActorClass; +import io.ray.api.id.ActorId; +import io.ray.streaming.api.Language; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.python.GraphPbBuilder; +import io.ray.streaming.runtime.rpc.RemoteCallWorker; +import io.ray.streaming.runtime.worker.JobWorker; +import io.ray.streaming.runtime.worker.context.JobWorkerContext; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Worker lifecycle controller is used to control JobWorker's creation, initiation and so on. */ +public class WorkerLifecycleController { + + private static final Logger LOG = LoggerFactory.getLogger(WorkerLifecycleController.class); + + public boolean createWorkers(List executionVertices) { + return asyncBatchExecute(this::createWorker, executionVertices); + } + + /** + * Create JobWorker actor according to the execution vertex. 
+ * + * @param executionVertex target execution vertex + * @return creation result + */ + private boolean createWorker(ExecutionVertex executionVertex) { + LOG.info( + "Start to create worker actor for vertex: {} with resource: {}, workeConfig: {}.", + executionVertex.getExecutionVertexName(), + executionVertex.getResource(), + executionVertex.getWorkerConfig()); + + Language language = executionVertex.getLanguage(); + + BaseActorHandle actor; + if (Language.JAVA == language) { + actor = + Ray.actor(JobWorker::new, executionVertex) + .setResources(executionVertex.getResource()) + .setMaxRestarts(-1) + .remote(); + } else { + RemoteCall.ExecutionVertexContext.ExecutionVertex vertexPb = + new GraphPbBuilder().buildVertex(executionVertex); + actor = + Ray.actor( + PyActorClass.of("ray.streaming.runtime.worker", "JobWorker"), + vertexPb.toByteArray()) + .setResources(executionVertex.getResource()) + .setMaxRestarts(-1) + .remote(); + } + + if (null == actor) { + LOG.error("Create worker actor failed."); + return false; + } + + executionVertex.setWorkerActor(actor); + + LOG.info( + "Worker actor created, actor: {}, vertex: {}.", + executionVertex.getWorkerActorId(), + executionVertex.getExecutionVertexName()); + return true; + } + + /** + * Using context to init JobWorker. 
+ * + * @param vertexToContextMap target JobWorker actor + * @param timeout timeout for waiting, unit: ms + * @return initiation result + */ + public boolean initWorkers( + Map vertexToContextMap, int timeout) { + LOG.info("Begin initiating workers: {}.", vertexToContextMap); + long startTime = System.currentTimeMillis(); + + Map, ActorId> rayObjects = new HashMap<>(); + vertexToContextMap + .entrySet() + .forEach( + (entry -> { + ExecutionVertex vertex = entry.getKey(); + rayObjects.put( + RemoteCallWorker.initWorker(vertex.getWorkerActor(), entry.getValue()), + vertex.getWorkerActorId()); + })); + + List> objectRefList = new ArrayList<>(rayObjects.keySet()); + + LOG.info("Waiting for workers' initialization."); + WaitResult result = Ray.wait(objectRefList, objectRefList.size(), timeout); + if (result.getReady().size() != objectRefList.size()) { + LOG.error("Initializing workers timeout[{} ms].", timeout); + return false; + } + + LOG.info("Finished waiting workers' initialization."); + LOG.info("Workers initialized. Cost {} ms.", System.currentTimeMillis() - startTime); + return true; + } + + /** + * Start JobWorkers to run task. 
+ * + * @param executionGraph physical plan + * @param timeout timeout for waiting, unit: ms + * @return starting result + */ + public boolean startWorkers(ExecutionGraph executionGraph, long lastCheckpointId, int timeout) { + LOG.info("Begin starting workers."); + long startTime = System.currentTimeMillis(); + List> objectRefs = new ArrayList<>(); + + // start source actors 1st + executionGraph + .getSourceActors() + .forEach(actor -> objectRefs.add(RemoteCallWorker.rollback(actor, lastCheckpointId))); + + // then start non-source actors + executionGraph + .getNonSourceActors() + .forEach(actor -> objectRefs.add(RemoteCallWorker.rollback(actor, lastCheckpointId))); + + WaitResult result = Ray.wait(objectRefs, objectRefs.size(), timeout); + if (result.getReady().size() != objectRefs.size()) { + LOG.error("Starting workers timeout[{} ms].", timeout); + return false; + } + + LOG.info("Workers started. Cost {} ms.", System.currentTimeMillis() - startTime); + return true; + } + + /** + * Stop and destroy JobWorkers' actor. 
+ * + * @param executionVertices target vertices + * @return destroy result + */ + public boolean destroyWorkers(List executionVertices) { + return asyncBatchExecute(this::destroyWorker, executionVertices); + } + + private boolean destroyWorker(ExecutionVertex executionVertex) { + BaseActorHandle rayActor = executionVertex.getWorkerActor(); + LOG.info( + "Begin destroying worker[vertex={}, actor={}].", + executionVertex.getExecutionVertexName(), + rayActor.getId()); + + boolean destroyResult = RemoteCallWorker.shutdownWithoutReconstruction(rayActor); + + if (!destroyResult) { + LOG.error( + "Failed to destroy JobWorker[{}]'s actor: {}.", + executionVertex.getExecutionVertexName(), + rayActor); + return false; + } + + LOG.info("Worker destroyed, actor: {}.", rayActor); + return true; + } + + /** + * Async batch execute function, for some cases that could not use Ray.wait + * + * @param operation the function to be executed + */ + private boolean asyncBatchExecute( + Function operation, List executionVertices) { + final Object asyncContext = Ray.getAsyncContext(); + + List> futureResults = + executionVertices.stream() + .map( + vertex -> + CompletableFuture.supplyAsync( + () -> { + Ray.setAsyncContext(asyncContext); + return operation.apply(vertex); + })) + .collect(Collectors.toList()); + + List succeeded = + futureResults.stream().map(CompletableFuture::join).collect(Collectors.toList()); + + if (succeeded.stream().anyMatch(x -> !x)) { + LOG.error("Not all futures return true, check ResourceManager'log the detail."); + return false; + } + return true; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/message/CallResult.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/message/CallResult.java new file mode 100644 index 00000000..d1459c40 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/message/CallResult.java @@ -0,0 +1,119 @@ +package 
io.ray.streaming.runtime.message; + +import com.google.common.base.MoreObjects; +import java.io.Serializable; + +public class CallResult implements Serializable { + + protected T resultObj; + private boolean success; + private int resultCode; + private String resultMsg; + + public CallResult() {} + + public CallResult(boolean success, int resultCode, String resultMsg, T resultObj) { + this.success = success; + this.resultCode = resultCode; + this.resultMsg = resultMsg; + this.resultObj = resultObj; + } + + public static CallResult success() { + return new CallResult<>(true, CallResultEnum.SUCCESS.code, CallResultEnum.SUCCESS.msg, null); + } + + public static CallResult success(T payload) { + return new CallResult<>(true, CallResultEnum.SUCCESS.code, CallResultEnum.SUCCESS.msg, payload); + } + + public static CallResult skipped(String msg) { + return new CallResult<>(true, CallResultEnum.SKIPPED.code, msg, null); + } + + public static CallResult fail() { + return new CallResult<>(false, CallResultEnum.FAILED.code, CallResultEnum.FAILED.msg, null); + } + + public static CallResult fail(T payload) { + return new CallResult<>(false, CallResultEnum.FAILED.code, CallResultEnum.FAILED.msg, payload); + } + + public static CallResult fail(String msg) { + return new CallResult<>(false, CallResultEnum.FAILED.code, msg, null); + } + + public static CallResult fail(CallResultEnum resultEnum, T payload) { + return new CallResult<>(false, resultEnum.code, resultEnum.msg, payload); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("resultObj", resultObj) + .add("success", success) + .add("resultCode", resultCode) + .add("resultMsg", resultMsg) + .toString(); + } + + public boolean isSuccess() { + return this.success; + } + + public void setSuccess(boolean success) { + this.success = success; + } + + public int getResultCode() { + return this.resultCode; + } + + public void setResultCode(int resultCode) { + this.resultCode = resultCode; 
+ } + + public CallResultEnum getResultEnum() { + return CallResultEnum.getEnum(this.resultCode); + } + + public String getResultMsg() { + return this.resultMsg; + } + + public void setResultMsg(String resultMsg) { + this.resultMsg = resultMsg; + } + + public T getResultObj() { + return this.resultObj; + } + + public void setResultObj(T resultObj) { + this.resultObj = resultObj; + } + + public enum CallResultEnum implements Serializable { + /** call result enum */ + SUCCESS(0, "SUCCESS"), + FAILED(1, "FAILED"), + SKIPPED(2, "SKIPPED"); + + public final int code; + public final String msg; + + CallResultEnum(int code, String msg) { + this.code = code; + this.msg = msg; + } + + public static CallResultEnum getEnum(int code) { + for (CallResultEnum value : CallResultEnum.values()) { + if (code == value.code) { + return value; + } + } + return FAILED; + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java new file mode 100644 index 00000000..2e3aa3cc --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java @@ -0,0 +1,152 @@ +package io.ray.streaming.runtime.python; + +import com.google.protobuf.ByteString; +import io.ray.runtime.actor.NativeActorHandle; +import io.ray.streaming.api.function.Function; +import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.operator.Operator; +import io.ray.streaming.python.PythonFunction; +import io.ray.streaming.python.PythonOperator; +import io.ray.streaming.python.PythonOperator.ChainedPythonOperator; +import io.ray.streaming.python.PythonPartition; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.generated.RemoteCall; +import 
io.ray.streaming.runtime.generated.Streaming; +import io.ray.streaming.runtime.serialization.MsgPackSerializer; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +public class GraphPbBuilder { + + private MsgPackSerializer serializer = new MsgPackSerializer(); + + public RemoteCall.ExecutionVertexContext buildExecutionVertexContext( + ExecutionVertex executionVertex) { + RemoteCall.ExecutionVertexContext.Builder builder = + RemoteCall.ExecutionVertexContext.newBuilder(); + + // build vertex + builder.setCurrentExecutionVertex(buildVertex(executionVertex)); + + // build upstream vertices + List upstreamVertices = executionVertex.getInputVertices(); + List upstreamVertexPbs = + upstreamVertices.stream().map(this::buildVertex).collect(Collectors.toList()); + builder.addAllUpstreamExecutionVertices(upstreamVertexPbs); + + // build downstream vertices + List downstreamVertices = executionVertex.getOutputVertices(); + List downstreamVertexPbs = + downstreamVertices.stream().map(this::buildVertex).collect(Collectors.toList()); + builder.addAllDownstreamExecutionVertices(downstreamVertexPbs); + + // build input edges + List inputEdges = executionVertex.getInputEdges(); + List inputEdgesPbs = + inputEdges.stream().map(this::buildEdge).collect(Collectors.toList()); + builder.addAllInputExecutionEdges(inputEdgesPbs); + + // build output edges + List outputEdges = executionVertex.getOutputEdges(); + List outputEdgesPbs = + outputEdges.stream().map(this::buildEdge).collect(Collectors.toList()); + builder.addAllOutputExecutionEdges(outputEdgesPbs); + + return builder.build(); + } + + public RemoteCall.ExecutionVertexContext.ExecutionVertex buildVertex( + ExecutionVertex executionVertex) { + // build vertex infos + RemoteCall.ExecutionVertexContext.ExecutionVertex.Builder executionVertexBuilder = + RemoteCall.ExecutionVertexContext.ExecutionVertex.newBuilder(); + 
executionVertexBuilder.setExecutionVertexId(executionVertex.getExecutionVertexId()); + executionVertexBuilder.setExecutionJobVertexId(executionVertex.getExecutionJobVertexId()); + executionVertexBuilder.setExecutionJobVertexName(executionVertex.getExecutionJobVertexName()); + executionVertexBuilder.setExecutionVertexIndex(executionVertex.getExecutionVertexIndex()); + executionVertexBuilder.setParallelism(executionVertex.getParallelism()); + executionVertexBuilder.setOperator( + ByteString.copyFrom(serializeOperator(executionVertex.getStreamOperator()))); + executionVertexBuilder.setChained(isPythonChainedOperator(executionVertex.getStreamOperator())); + if (executionVertex.getWorkerActor() != null) { + executionVertexBuilder.setWorkerActor( + ByteString.copyFrom(((NativeActorHandle) (executionVertex.getWorkerActor())).toBytes())); + } + executionVertexBuilder.setContainerId(executionVertex.getContainerId().toString()); + executionVertexBuilder.setBuildTime(executionVertex.getBuildTime()); + executionVertexBuilder.setLanguage( + Streaming.Language.valueOf(executionVertex.getLanguage().name())); + executionVertexBuilder.putAllConfig(executionVertex.getWorkerConfig()); + executionVertexBuilder.putAllResource(executionVertex.getResource()); + + return executionVertexBuilder.build(); + } + + private RemoteCall.ExecutionVertexContext.ExecutionEdge buildEdge(ExecutionEdge executionEdge) { + // build edge infos + RemoteCall.ExecutionVertexContext.ExecutionEdge.Builder executionEdgeBuilder = + RemoteCall.ExecutionVertexContext.ExecutionEdge.newBuilder(); + executionEdgeBuilder.setSourceExecutionVertexId(executionEdge.getSourceVertexId()); + executionEdgeBuilder.setTargetExecutionVertexId(executionEdge.getTargetVertexId()); + executionEdgeBuilder.setPartition( + ByteString.copyFrom(serializePartition(executionEdge.getPartition()))); + + return executionEdgeBuilder.build(); + } + + private byte[] serializeOperator(Operator operator) { + if (operator instanceof PythonOperator) 
{ + if (isPythonChainedOperator(operator)) { + return serializePythonChainedOperator((ChainedPythonOperator) operator); + } else { + PythonOperator pythonOperator = (PythonOperator) operator; + return serializer.serialize( + Arrays.asList( + serializeFunction(pythonOperator.getFunction()), + pythonOperator.getModuleName(), + pythonOperator.getClassName())); + } + } else { + return new byte[0]; + } + } + + private boolean isPythonChainedOperator(Operator operator) { + return operator instanceof ChainedPythonOperator; + } + + private byte[] serializePythonChainedOperator(ChainedPythonOperator operator) { + List serializedOperators = + operator.getOperators().stream().map(this::serializeOperator).collect(Collectors.toList()); + return serializer.serialize(Arrays.asList(serializedOperators, operator.getConfigs())); + } + + private byte[] serializeFunction(Function function) { + if (function instanceof PythonFunction) { + PythonFunction pyFunc = (PythonFunction) function; + // function_bytes, module_name, function_name, function_interface + return serializer.serialize( + Arrays.asList( + pyFunc.getFunction(), pyFunc.getModuleName(), + pyFunc.getFunctionName(), pyFunc.getFunctionInterface())); + } else { + return new byte[0]; + } + } + + private byte[] serializePartition(Partition partition) { + if (partition instanceof PythonPartition) { + PythonPartition pythonPartition = (PythonPartition) partition; + // partition_bytes, module_name, function_name + return serializer.serialize( + Arrays.asList( + pythonPartition.getPartition(), + pythonPartition.getModuleName(), + pythonPartition.getFunctionName())); + } else { + return new byte[0]; + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java new file mode 100644 index 00000000..4656d98c --- /dev/null +++ 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java @@ -0,0 +1,262 @@ +package io.ray.streaming.runtime.python; + +import com.google.common.base.Preconditions; +import com.google.common.primitives.Primitives; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.stream.DataStream; +import io.ray.streaming.api.stream.Stream; +import io.ray.streaming.python.PythonFunction; +import io.ray.streaming.python.PythonPartition; +import io.ray.streaming.python.stream.PythonDataStream; +import io.ray.streaming.python.stream.PythonStreamSource; +import io.ray.streaming.runtime.serialization.MsgPackSerializer; +import io.ray.streaming.runtime.util.ReflectionUtils; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Gateway for streaming python api. All calls on DataStream in python will be mapped to DataStream + * call in java by this PythonGateway using ray calls. 
This class needs to be in sync with + * GatewayClient in `streaming/python/runtime/gateway_client.py` + */ +@SuppressWarnings("unchecked") +public class PythonGateway { + + private static final Logger LOG = LoggerFactory.getLogger(PythonGateway.class); + private static final String REFERENCE_ID_PREFIX = "__gateway_reference_id__"; + private static MsgPackSerializer serializer = new MsgPackSerializer(); + + private Map referenceMap; + private StreamingContext streamingContext; + + public PythonGateway() { + referenceMap = new HashMap<>(); + LOG.info("PythonGateway created"); + } + + public byte[] createStreamingContext() { + streamingContext = StreamingContext.buildContext(); + LOG.info("StreamingContext created"); + referenceMap.put(getReferenceId(streamingContext), streamingContext); + return serializer.serialize(getReferenceId(streamingContext)); + } + + public StreamingContext getStreamingContext() { + return streamingContext; + } + + public byte[] withConfig(byte[] confBytes) { + Preconditions.checkNotNull(streamingContext); + try { + Map config = (Map) serializer.deserialize(confBytes); + LOG.info("Set config {}", config); + streamingContext.withConfig(config); + // We can't use `return void`, that will make `ray.get()` hang forever. + // We can't use `return new byte[0]`, that will make `ray::CoreWorker::ExecuteTask` crash. + // So we `return new byte[1]` for method execution success. + // Same for other methods in this class which return new byte[1]. 
+ return new byte[1]; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public byte[] createPythonStreamSource(byte[] pySourceFunc) { + Preconditions.checkNotNull(streamingContext); + try { + PythonStreamSource pythonStreamSource = + PythonStreamSource.from(streamingContext, new PythonFunction(pySourceFunc)); + referenceMap.put(getReferenceId(pythonStreamSource), pythonStreamSource); + return serializer.serialize(getReferenceId(pythonStreamSource)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public byte[] execute(byte[] jobNameBytes) { + LOG.info("Starting executing"); + streamingContext.execute((String) serializer.deserialize(jobNameBytes)); + // see `withConfig` method. + return new byte[1]; + } + + public byte[] createPyFunc(byte[] pyFunc) { + PythonFunction function = new PythonFunction(pyFunc); + referenceMap.put(getReferenceId(function), function); + return serializer.serialize(getReferenceId(function)); + } + + public byte[] createPyPartition(byte[] pyPartition) { + PythonPartition partition = new PythonPartition(pyPartition); + referenceMap.put(getReferenceId(partition), partition); + return serializer.serialize(getReferenceId(partition)); + } + + public byte[] union(byte[] paramsBytes) { + List streams = (List) serializer.deserialize(paramsBytes); + streams = processParameters(streams); + LOG.info("Call union with streams {}", streams); + Preconditions.checkArgument(streams.size() >= 2, "Union needs at least two streams"); + Stream unionStream; + Stream stream1 = (Stream) streams.get(0); + List otherStreams = streams.subList(1, streams.size()); + if (stream1 instanceof DataStream) { + DataStream dataStream = (DataStream) stream1; + unionStream = dataStream.union(otherStreams); + } else { + Preconditions.checkArgument(stream1 instanceof PythonDataStream); + PythonDataStream pythonDataStream = (PythonDataStream) stream1; + unionStream = pythonDataStream.union(otherStreams); + } + return serialize(unionStream); 
+ } + + public byte[] callFunction(byte[] paramsBytes) { + try { + List params = (List) serializer.deserialize(paramsBytes); + params = processParameters(params); + LOG.info("callFunction params {}", params); + String className = (String) params.get(0); + String funcName = (String) params.get(1); + Class clz = Class.forName(className, true, this.getClass().getClassLoader()); + Class[] paramsTypes = + params.subList(2, params.size()).stream().map(Object::getClass).toArray(Class[]::new); + Method method = findMethod(clz, funcName, paramsTypes); + Object result = method.invoke(null, params.subList(2, params.size()).toArray()); + return serialize(result); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public byte[] callMethod(byte[] paramsBytes) { + try { + List params = (List) serializer.deserialize(paramsBytes); + params = processParameters(params); + LOG.info("callMethod params {}", params); + Object obj = params.get(0); + String methodName = (String) params.get(1); + Class clz = obj.getClass(); + Class[] paramsTypes = + params.subList(2, params.size()).stream().map(Object::getClass).toArray(Class[]::new); + Method method = findMethod(clz, methodName, paramsTypes); + Object result = method.invoke(obj, params.subList(2, params.size()).toArray()); + return serialize(result); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static Method findMethod(Class cls, String methodName, Class[] paramsTypes) { + List methods = ReflectionUtils.findMethods(cls, methodName); + if (methods.size() == 1) { + return methods.get(0); + } + // Convert all params types to primitive types if it's boxed type + Class[] unwrappedTypes = + Arrays.stream(paramsTypes) + .map((Function) Primitives::unwrap) + .toArray(Class[]::new); + Optional any = + methods.stream() + .filter( + m -> { + boolean exactMatch = + Arrays.equals(m.getParameterTypes(), paramsTypes) + || Arrays.equals(m.getParameterTypes(), unwrappedTypes); + if (exactMatch) { + 
return true; + } else if (paramsTypes.length == m.getParameterTypes().length) { + for (int i = 0; i < m.getParameterTypes().length; i++) { + Class parameterType = m.getParameterTypes()[i]; + if (!parameterType.isAssignableFrom(paramsTypes[i])) { + return false; + } + } + return true; + } else { + return false; + } + }) + .findAny(); + Preconditions.checkArgument( + any.isPresent(), + String.format( + "Method %s with type %s doesn't exist on class %s", + methodName, Arrays.toString(paramsTypes), cls)); + return any.get(); + } + + private byte[] serialize(Object value) { + if (returnReference(value)) { + referenceMap.put(getReferenceId(value), value); + return serializer.serialize(getReferenceId(value)); + } else { + return serializer.serialize(value); + } + } + + private static boolean returnReference(Object value) { + if (isBasic(value)) { + return false; + } else { + try { + serializer.serialize(value); + return false; + } catch (Exception e) { + return true; + } + } + } + + private static boolean isBasic(Object value) { + return value == null + || (value instanceof Boolean) + || (value instanceof Number) + || (value instanceof String) + || (value instanceof byte[]); + } + + public byte[] newInstance(byte[] classNameBytes) { + String className = (String) serializer.deserialize(classNameBytes); + try { + Class clz = Class.forName(className, true, Thread.currentThread().getContextClassLoader()); + Object instance = clz.newInstance(); + referenceMap.put(getReferenceId(instance), instance); + return serializer.serialize(getReferenceId(instance)); + } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { + throw new IllegalArgumentException( + String.format("Create instance for class %s failed", className), e); + } + } + + private List processParameters(List params) { + return params.stream().map(this::processParameter).collect(Collectors.toList()); + } + + private Object processParameter(Object o) { + if (o instanceof String) { + Object 
value = referenceMap.get(o); + if (value != null) { + return value; + } + } + // Since python can't represent byte/short, we convert all Byte/Short to Integer + if (o instanceof Byte || o instanceof Short) { + return ((Number) o).intValue(); + } + return o; + } + + private String getReferenceId(Object o) { + return REFERENCE_ID_PREFIX + System.identityHashCode(o); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/PbResultParser.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/PbResultParser.java new file mode 100644 index 00000000..48ba7aa7 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/PbResultParser.java @@ -0,0 +1,58 @@ +package io.ray.streaming.runtime.rpc; + +import com.google.protobuf.InvalidProtocolBufferException; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.message.CallResult; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; +import java.util.HashMap; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class PbResultParser { + + private static final Logger LOG = LoggerFactory.getLogger(PbResultParser.class); + + public static Boolean parseBoolResult(byte[] result) { + if (null == result) { + LOG.warn("Result is null."); + return false; + } + + RemoteCall.BoolResult boolResult; + try { + boolResult = RemoteCall.BoolResult.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + LOG.error("Parse boolean result has exception.", e); + return false; + } + + return boolResult.getBoolRes(); + } + + public static CallResult parseRollbackResult(byte[] bytes) { + RemoteCall.CallResult callResultPb; + try { + callResultPb = RemoteCall.CallResult.parseFrom(bytes); + } catch (InvalidProtocolBufferException e) { + LOG.error("Rollback parse result has exception.", e); + return CallResult.fail(); + } + + CallResult callResult = new 
CallResult<>(); + callResult.setSuccess(callResultPb.getSuccess()); + callResult.setResultCode(callResultPb.getResultCode()); + callResult.setResultMsg(callResultPb.getResultMsg()); + RemoteCall.QueueRecoverInfo recoverInfo = callResultPb.getResultObj(); + Map creationStatusMap = new HashMap<>(); + recoverInfo + .getCreationStatusMap() + .forEach( + (k, v) -> { + creationStatusMap.put( + k, ChannelRecoverInfo.ChannelCreationStatus.fromInt(v.getNumber())); + }); + callResult.setResultObj(new ChannelRecoverInfo(creationStatusMap)); + return callResult; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallMaster.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallMaster.java new file mode 100644 index 00000000..1dc69751 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallMaster.java @@ -0,0 +1,51 @@ +package io.ray.streaming.runtime.rpc; + +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport; +import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest; + +public class RemoteCallMaster { + + public static ObjectRef reportJobWorkerCommitAsync( + ActorHandle actor, WorkerCommitReport commitReport) { + RemoteCall.WorkerCommitReport commit = + RemoteCall.WorkerCommitReport.newBuilder() + .setCommitCheckpointId(commitReport.commitCheckpointId) + .build(); + Any detail = Any.pack(commit); + RemoteCall.BaseWorkerCmd cmd = + RemoteCall.BaseWorkerCmd.newBuilder() + .setActorId(ByteString.copyFrom(commitReport.fromActorId.getBytes())) + .setTimestamp(System.currentTimeMillis()) + .setDetail(detail) + .build(); + + return 
actor.task(JobMaster::reportJobWorkerCommit, cmd.toByteArray()).remote(); + } + + public static Boolean requestJobWorkerRollback( + ActorHandle actor, WorkerRollbackRequest rollbackRequest) { + RemoteCall.WorkerRollbackRequest request = + RemoteCall.WorkerRollbackRequest.newBuilder() + .setExceptionMsg(rollbackRequest.getRollbackExceptionMsg()) + .setWorkerHostname(rollbackRequest.getHostname()) + .setWorkerPid(rollbackRequest.getPid()) + .build(); + Any detail = Any.pack(request); + RemoteCall.BaseWorkerCmd cmd = + RemoteCall.BaseWorkerCmd.newBuilder() + .setActorId(ByteString.copyFrom(rollbackRequest.fromActorId.getBytes())) + .setTimestamp(System.currentTimeMillis()) + .setDetail(detail) + .build(); + ObjectRef ret = + actor.task(JobMaster::requestJobWorkerRollback, cmd.toByteArray()).remote(); + byte[] res = ret.get(); + return PbResultParser.parseBoolResult(res); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallWorker.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallWorker.java new file mode 100644 index 00000000..6cd78813 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallWorker.java @@ -0,0 +1,201 @@ +package io.ray.streaming.runtime.rpc; + +import io.ray.api.ActorHandle; +import io.ray.api.BaseActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.PyActorHandle; +import io.ray.api.Ray; +import io.ray.api.function.PyActorMethod; +import io.ray.api.function.RayFunc3; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.worker.JobWorker; +import io.ray.streaming.runtime.worker.context.JobWorkerContext; +import java.util.ArrayList; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Ray call worker. It takes the communication job from {@link JobMaster} to {@link JobWorker}. 
*/ +public class RemoteCallWorker { + + private static final Logger LOG = LoggerFactory.getLogger(RemoteCallWorker.class); + + /** + * Call JobWorker actor to init. + * + * @param actor target JobWorker actor + * @param context JobWorker's context + * @return init result + */ + public static ObjectRef initWorker(BaseActorHandle actor, JobWorkerContext context) { + LOG.info("Call worker to initiate, actor: {}, context: {}.", actor.getId(), context); + ObjectRef result; + + // python + if (actor instanceof PyActorHandle) { + result = + ((PyActorHandle) actor) + .task(PyActorMethod.of("init", Boolean.class), context.getPythonWorkerContextBytes()) + .remote(); + } else { + // java + result = ((ActorHandle) actor).task(JobWorker::init, context).remote(); + } + + LOG.info("Finished calling worker to initiate."); + return result; + } + + /** + * Call JobWorker actor to start. + * + * @param actor target JobWorker actor + * @param checkpointId checkpoint ID to be rollback + * @return start result + */ + public static ObjectRef rollback(BaseActorHandle actor, final Long checkpointId) { + LOG.info("Call worker to start, actor: {}.", actor.getId()); + ObjectRef result; + + // python + if (actor instanceof PyActorHandle) { + RemoteCall.CheckpointId checkpointIdPb = + RemoteCall.CheckpointId.newBuilder().setCheckpointId(checkpointId).build(); + result = + ((PyActorHandle) actor) + .task(PyActorMethod.of("rollback"), checkpointIdPb.toByteArray()) + .remote(); + } else { + // java + result = + ((ActorHandle) actor) + .task(JobWorker::rollback, checkpointId, System.currentTimeMillis()) + .remote(); + } + + LOG.info("Finished calling worker to start."); + return result; + } + + /** + * Call JobWorker actor to destroy without reconstruction. 
+ * + * @param actor target JobWorker actor + * @return destroy result + */ + public static Boolean shutdownWithoutReconstruction(BaseActorHandle actor) { + LOG.info("Call worker to shutdown without reconstruction, actor is {}.", actor.getId()); + Boolean result = false; + + // TODO (datayjz): ray call worker to destroy + + LOG.info("Finished calling wk shutdownWithoutReconstruction, result is {}.", result); + return result; + } + + public static ObjectRef triggerCheckpoint(BaseActorHandle actor, Long barrierId) { + // python + if (actor instanceof PyActorHandle) { + RemoteCall.Barrier barrierPb = RemoteCall.Barrier.newBuilder().setId(barrierId).build(); + return ((PyActorHandle) actor) + .task(PyActorMethod.of("commit"), barrierPb.toByteArray()) + .remote(); + } else { + // java + return ((ActorHandle) actor) + .task(JobWorker::triggerCheckpoint, barrierId) + .remote(); + } + } + + public static void clearExpiredCheckpointParallel( + List actors, Long stateCheckpointId, Long queueCheckpointId) { + if (LOG.isInfoEnabled()) { + LOG.info( + "Call worker clearExpiredCheckpoint, state checkpoint id is {}," + + " queue checkpoint id is {}.", + stateCheckpointId, + queueCheckpointId); + } + + List result = + checkpointCompleteCommonCallTwoWay( + actors, + stateCheckpointId, + queueCheckpointId, + "clear_expired_cp", + JobWorker::clearExpiredCheckpoint); + + if (LOG.isInfoEnabled()) { + result.forEach( + obj -> LOG.info("Finish call worker clearExpiredCheckpointParallel, ret is {}.", obj)); + } + } + + public static void notifyCheckpointTimeoutParallel( + List actors, Long checkpointId) { + LOG.info("Call worker notifyCheckpointTimeoutParallel, checkpoint id is {}", checkpointId); + + actors.forEach( + actor -> { + if (actor instanceof PyActorHandle) { + RemoteCall.CheckpointId checkpointIdPb = + RemoteCall.CheckpointId.newBuilder().setCheckpointId(checkpointId).build(); + ((PyActorHandle) actor) + .task(PyActorMethod.of("notify_checkpoint_timeout"), 
checkpointIdPb.toByteArray()) + .remote(); + } else { + ((ActorHandle) actor) + .task(JobWorker::notifyCheckpointTimeout, checkpointId) + .remote(); + } + }); + + LOG.info("Finish call worker notifyCheckpointTimeoutParallel."); + } + + private static List checkpointCompleteCommonCallTwoWay( + List actors, + Long stateCheckpointId, + Long queueCheckpointId, + String pyFuncName, + RayFunc3 rayFunc) { + List> waitFor = + checkpointCompleteCommonCall( + actors, stateCheckpointId, queueCheckpointId, pyFuncName, rayFunc); + return Ray.get(waitFor); + } + + private static List> checkpointCompleteCommonCall( + List actors, + Long stateCheckpointId, + Long queueCheckpointId, + String pyFuncName, + RayFunc3 rayFunc) { + List> waitFor = new ArrayList<>(); + actors.forEach( + actor -> { + // python + if (actor instanceof PyActorHandle) { + RemoteCall.CheckpointId stateCheckpointIdPb = + RemoteCall.CheckpointId.newBuilder().setCheckpointId(stateCheckpointId).build(); + + RemoteCall.CheckpointId queueCheckpointIdPb = + RemoteCall.CheckpointId.newBuilder().setCheckpointId(queueCheckpointId).build(); + waitFor.add( + ((PyActorHandle) actor) + .task( + PyActorMethod.of(pyFuncName), + stateCheckpointIdPb.toByteArray(), + queueCheckpointIdPb.toByteArray()) + .remote()); + } else { + // java + waitFor.add( + ((ActorHandle) actor).task(rayFunc, stateCheckpointId, queueCheckpointId).remote()); + } + }); + return waitFor; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/AsyncRemoteCaller.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/AsyncRemoteCaller.java new file mode 100644 index 00000000..360c1b23 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/AsyncRemoteCaller.java @@ -0,0 +1,149 @@ +package io.ray.streaming.runtime.rpc.async; + +import io.ray.api.ActorHandle; +import io.ray.api.BaseActorHandle; +import io.ray.api.ObjectRef; +import 
io.ray.api.PyActorHandle; +import io.ray.api.function.PyActorMethod; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.message.CallResult; +import io.ray.streaming.runtime.rpc.PbResultParser; +import io.ray.streaming.runtime.rpc.async.RemoteCallPool.Callback; +import io.ray.streaming.runtime.rpc.async.RemoteCallPool.ExceptionHandler; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; +import io.ray.streaming.runtime.worker.JobWorker; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressWarnings("unchecked") +public class AsyncRemoteCaller { + + private static final Logger LOG = LoggerFactory.getLogger(AsyncRemoteCaller.class); + private RemoteCallPool remoteCallPool = new RemoteCallPool(); + + /** + * Call JobWorker::checkIfNeedRollback async + * + * @param actor JobWorker actor + * @param callback callback function on success + * @param onException callback function on exception + */ + public void checkIfNeedRollbackAsync( + BaseActorHandle actor, Callback callback, ExceptionHandler onException) { + if (actor instanceof PyActorHandle) { + // python + remoteCallPool.bindCallback( + ((PyActorHandle) actor).task(PyActorMethod.of("check_if_need_rollback")).remote(), + (obj) -> { + byte[] res = (byte[]) obj; + callback.handle(PbResultParser.parseBoolResult(res)); + }, + onException); + } else { + // java + remoteCallPool.bindCallback( + ((ActorHandle) actor) + .task(JobWorker::checkIfNeedRollback, System.currentTimeMillis()) + .remote(), + callback, + onException); + } + } + + /** + * Call JobWorker::rollback async + * + * @param actor JobWorker actor + * @param callback callback function on success + * @param onException callback function on exception + */ + public void rollback( + BaseActorHandle actor, + final Long checkpointId, + Callback> callback, + 
ExceptionHandler onException) { + // python + if (actor instanceof PyActorHandle) { + RemoteCall.CheckpointId checkpointIdPb = + RemoteCall.CheckpointId.newBuilder().setCheckpointId(checkpointId).build(); + ObjectRef call = + ((PyActorHandle) actor) + .task(PyActorMethod.of("rollback"), checkpointIdPb.toByteArray()) + .remote(); + remoteCallPool.bindCallback( + call, + obj -> callback.handle(PbResultParser.parseRollbackResult((byte[]) obj)), + onException); + } else { + // java + ObjectRef call = + ((ActorHandle) actor) + .task(JobWorker::rollback, checkpointId, System.currentTimeMillis()) + .remote(); + remoteCallPool.bindCallback( + call, + obj -> { + CallResult res = (CallResult) obj; + callback.handle(res); + }, + onException); + } + } + + /** + * Call JobWorker::rollback async in batch + * + * @param actors JobWorker actor list + * @param callback callback function on success + * @param onException callback function on exception + */ + public void batchRollback( + List actors, + final Long checkpointId, + Collection abnormalQueues, + Callback>> callback, + ExceptionHandler onException) { + List> rayCallList = new ArrayList<>(); + Map isPyActor = new HashMap<>(); + for (int i = 0; i < actors.size(); ++i) { + BaseActorHandle actor = actors.get(i); + ObjectRef call; + if (actor instanceof PyActorHandle) { + isPyActor.put(i, true); + RemoteCall.CheckpointId checkpointIdPb = + RemoteCall.CheckpointId.newBuilder().setCheckpointId(checkpointId).build(); + call = + ((PyActorHandle) actor) + .task(PyActorMethod.of("rollback"), checkpointIdPb.toByteArray()) + .remote(); + } else { + // java + call = + ((ActorHandle) actor) + .task(JobWorker::rollback, checkpointId, System.currentTimeMillis()) + .remote(); + } + rayCallList.add(call); + } + remoteCallPool.bindCallback( + rayCallList, + objList -> { + List> results = new ArrayList<>(); + for (int i = 0; i < objList.size(); ++i) { + Object obj = objList.get(i); + if (isPyActor.getOrDefault(i, false)) { + 
results.add(PbResultParser.parseRollbackResult((byte[]) obj)); + } else { + results.add((CallResult) obj); + } + } + callback.handle(results); + }, + onException); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/RemoteCallPool.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/RemoteCallPool.java new file mode 100644 index 00000000..a60d916f --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/RemoteCallPool.java @@ -0,0 +1,195 @@ +package io.ray.streaming.runtime.rpc.async; + +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.api.WaitResult; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RemoteCallPool implements Runnable { + + private static final Logger LOG = LoggerFactory.getLogger(RemoteCallPool.class); + private static final int WAIT_TIME_MS = 5; + private static final long WARNING_PERIOD = 10000; + private final List pendingObjectBundles = new LinkedList<>(); + private Map> singletonHandlerMap = new ConcurrentHashMap<>(); + private Map>> bundleHandlerMap = + new ConcurrentHashMap<>(); + private Map> bundleExceptionHandlerMap = + new ConcurrentHashMap<>(); + private ThreadPoolExecutor callBackPool = + new ThreadPoolExecutor( + 2, + Runtime.getRuntime().availableProcessors(), + 1, + TimeUnit.MINUTES, + new LinkedBlockingQueue<>(), + new CallbackThreadFactory()); + private volatile boolean stop = false; + + public RemoteCallPool() { 
+ Thread t = new Thread(Ray.wrapRunnable(this), "remote-pool-loop"); + t.setUncaughtExceptionHandler( + (thread, throwable) -> LOG.error("Error in remote call pool thread.", throwable)); + t.start(); + } + + @SuppressWarnings("unchecked") + public void bindCallback( + ObjectRef obj, Callback callback, ExceptionHandler onException) { + List objectRefList = Collections.singletonList(obj); + RemoteCallBundle bundle = new RemoteCallBundle(objectRefList, true); + singletonHandlerMap.put(bundle, (Callback) callback); + bundleExceptionHandlerMap.put(bundle, onException); + synchronized (pendingObjectBundles) { + pendingObjectBundles.add(bundle); + } + } + + public void bindCallback( + List> objectBundle, + Callback> callback, + ExceptionHandler onException) { + RemoteCallBundle bundle = new RemoteCallBundle(objectBundle, false); + bundleHandlerMap.put(bundle, callback); + bundleExceptionHandlerMap.put(bundle, onException); + synchronized (pendingObjectBundles) { + pendingObjectBundles.add(bundle); + } + } + + public void stop() { + stop = true; + } + + public void run() { + while (!stop) { + try { + if (pendingObjectBundles.isEmpty()) { + Thread.sleep(WAIT_TIME_MS); + continue; + } + synchronized (pendingObjectBundles) { + Iterator itr = pendingObjectBundles.iterator(); + while (itr.hasNext()) { + RemoteCallBundle bundle = itr.next(); + WaitResult waitResult = + Ray.wait(bundle.objects, bundle.objects.size(), WAIT_TIME_MS); + List> readyObjs = waitResult.getReady(); + if (readyObjs.size() != bundle.objects.size()) { + long now = System.currentTimeMillis(); + long waitingTime = now - bundle.createTime; + if (waitingTime > WARNING_PERIOD && now - bundle.lastWarnTs > WARNING_PERIOD) { + bundle.lastWarnTs = now; + LOG.warn("Bundle has being waiting for {} ms, bundle = {}.", waitingTime, bundle); + } + continue; + } + + ExceptionHandler exceptionHandler = bundleExceptionHandlerMap.get(bundle); + if (bundle.isSingletonBundle) { + callBackPool.execute( + Ray.wrapRunnable( + () 
-> { + try { + singletonHandlerMap.get(bundle).handle(readyObjs.get(0).get()); + singletonHandlerMap.remove(bundle); + } catch (Throwable th) { + LOG.error( + "Error when get object, objectId = {}.", + readyObjs.get(0).toString(), + th); + if (exceptionHandler != null) { + exceptionHandler.handle(th); + } + } + })); + } else { + List results = + readyObjs.stream().map(ObjectRef::get).collect(Collectors.toList()); + List resultIds = + readyObjs.stream().map(ObjectRef::toString).collect(Collectors.toList()); + callBackPool.execute( + Ray.wrapRunnable( + () -> { + try { + bundleHandlerMap.get(bundle).handle(results); + bundleHandlerMap.remove(bundle); + } catch (Throwable th) { + LOG.error("Error when get object, objectIds = {}.", resultIds, th); + if (exceptionHandler != null) { + exceptionHandler.handle(th); + } + } + })); + } + itr.remove(); + } + } + + } catch (Exception e) { + LOG.error("Exception in wait loop.", e); + } + } + LOG.info("Wait loop finished."); + } + + @FunctionalInterface + public interface ExceptionHandler { + + void handle(T object); + } + + @FunctionalInterface + public interface Callback { + + void handle(T object) throws Throwable; + } + + private static class RemoteCallBundle { + + List> objects; + boolean isSingletonBundle; + long lastWarnTs = System.currentTimeMillis(); + long createTime = System.currentTimeMillis(); + + RemoteCallBundle(List> objects, boolean isSingletonBundle) { + this.objects = objects; + this.isSingletonBundle = isSingletonBundle; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("["); + objects.forEach(rayObj -> sb.append(rayObj.toString()).append(",")); + sb.append("]"); + return sb.toString(); + } + } + + static class CallbackThreadFactory implements ThreadFactory { + + private AtomicInteger cnt = new AtomicInteger(0); + + @Override + public Thread newThread(Runnable r) { + Thread t = new Thread(r); + t.setUncaughtExceptionHandler((thread, throwable) -> 
LOG.error("Callback err.", throwable)); + t.setName("callback-thread-" + cnt.getAndIncrement()); + return t; + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/CrossLangSerializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/CrossLangSerializer.java new file mode 100644 index 00000000..ee987146 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/CrossLangSerializer.java @@ -0,0 +1,62 @@ +package io.ray.streaming.runtime.serialization; + +import io.ray.streaming.message.KeyRecord; +import io.ray.streaming.message.Record; +import java.util.Arrays; +import java.util.List; + +/** + * A serializer for cross-lang serialization between java/python. TODO implements a more + * sophisticated serialization framework + */ +public class CrossLangSerializer implements Serializer { + + private static final byte RECORD_TYPE_ID = 0; + private static final byte KEY_RECORD_TYPE_ID = 1; + + private MsgPackSerializer msgPackSerializer = new MsgPackSerializer(); + + public byte[] serialize(Object object) { + Record record = (Record) object; + Object value = record.getValue(); + Class clz = record.getClass(); + if (clz == Record.class) { + return msgPackSerializer.serialize(Arrays.asList(RECORD_TYPE_ID, record.getStream(), value)); + } else if (clz == KeyRecord.class) { + KeyRecord keyRecord = (KeyRecord) record; + Object key = keyRecord.getKey(); + return msgPackSerializer.serialize( + Arrays.asList(KEY_RECORD_TYPE_ID, keyRecord.getStream(), key, value)); + } else { + throw new UnsupportedOperationException( + String.format("Serialize %s is unsupported.", record)); + } + } + + @SuppressWarnings("unchecked") + public Record deserialize(byte[] bytes) { + List list = (List) msgPackSerializer.deserialize(bytes); + Byte typeId = (Byte) list.get(0); + switch (typeId) { + case RECORD_TYPE_ID: + { + String stream = (String) list.get(1); 
+ // value payload of a plain Record.
+ Object value = list.get(2);
+ Record record = new Record(value);
+ record.setStream(stream);
+ return record;
+ }
+ case KEY_RECORD_TYPE_ID:
+ {
+ // KeyRecord adds a key field between the stream name and the value.
+ String stream = (String) list.get(1);
+ Object key = list.get(2);
+ Object value = list.get(3);
+ KeyRecord keyRecord = new KeyRecord(key, value);
+ keyRecord.setStream(stream);
+ return keyRecord;
+ }
+ default:
+ throw new UnsupportedOperationException("Unsupported type " + typeId);
+ }
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/JavaSerializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/JavaSerializer.java
new file mode 100644
index 00000000..42072408
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/JavaSerializer.java
@@ -0,0 +1,16 @@
+package io.ray.streaming.runtime.serialization;
+
+import io.ray.runtime.serializer.FstSerializer;
+
+/** Java-side {@link Serializer} that delegates both directions to Ray's FstSerializer. */
+public class JavaSerializer implements Serializer {
+
+ @Override
+ public byte[] serialize(Object object) {
+ return FstSerializer.encode(object);
+ }
+
+ @Override
+ // NOTE(review): generic markup (<T>) appears stripped by the diff extraction here;
+ // kept byte-identical to the source chunk.
+ public T deserialize(byte[] bytes) {
+ return FstSerializer.decode(bytes);
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/MsgPackSerializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/MsgPackSerializer.java
new file mode 100644
index 00000000..2fc9a2c3
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/MsgPackSerializer.java
@@ -0,0 +1,127 @@
+package io.ray.streaming.runtime.serialization;
+
+import com.google.common.io.BaseEncoding;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.msgpack.core.MessageBufferPacker;
+import org.msgpack.core.MessagePack;
+import org.msgpack.core.MessageUnpacker;
+import org.msgpack.value.ArrayValue; +import org.msgpack.value.FloatValue; +import org.msgpack.value.IntegerValue; +import org.msgpack.value.MapValue; +import org.msgpack.value.Value; + +public class MsgPackSerializer { + + public byte[] serialize(Object obj) { + MessageBufferPacker packer = MessagePack.newDefaultBufferPacker(); + serialize(obj, packer); + return packer.toByteArray(); + } + + private void serialize(Object obj, MessageBufferPacker packer) { + try { + if (obj == null) { + packer.packNil(); + } else { + Class clz = obj.getClass(); + if (clz == Boolean.class) { + packer.packBoolean((Boolean) obj); + } else if (clz == Byte.class) { + packer.packByte((Byte) obj); + } else if (clz == Short.class) { + packer.packShort((Short) obj); + } else if (clz == Integer.class) { + packer.packInt((Integer) obj); + } else if (clz == Long.class) { + packer.packLong((Long) obj); + } else if (clz == Double.class) { + packer.packDouble((Double) obj); + } else if (clz == byte[].class) { + byte[] bytes = (byte[]) obj; + packer.packBinaryHeader(bytes.length); + packer.writePayload(bytes); + } else if (clz == String.class) { + packer.packString((String) obj); + } else if (obj instanceof Collection) { + Collection collection = (Collection) (obj); + packer.packArrayHeader(collection.size()); + for (Object o : collection) { + serialize(o, packer); + } + } else if (obj instanceof Map) { + Map map = (Map) (obj); + packer.packMapHeader(map.size()); + for (Object o : map.entrySet()) { + Map.Entry e = (Map.Entry) o; + serialize(e.getKey(), packer); + serialize(e.getValue(), packer); + } + } else { + throw new UnsupportedOperationException("Unsupported type " + clz); + } + } + } catch (Exception e) { + throw new RuntimeException("Serialize error for object " + obj, e); + } + } + + public Object deserialize(byte[] bytes) { + try { + MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(bytes); + return deserialize(unpacker.unpackValue()); + } catch (Exception e) { + String hex = 
BaseEncoding.base16().lowerCase().encode(bytes);
+ // Hex-dump the offending payload so the failure is diagnosable from logs.
+ throw new RuntimeException("Deserialize error: " + hex, e);
+ }
+ }
+
+ // Recursively converts a decoded msgpack Value tree into plain Java objects:
+ // Boolean / Byte / Short / Integer / Long / BigInteger / Double / String /
+ // byte[] / ArrayList / HashMap.
+ private Object deserialize(Value value) {
+ switch (value.getValueType()) {
+ case NIL:
+ return null;
+ case BOOLEAN:
+ return value.asBooleanValue().getBoolean();
+ case INTEGER:
+ // Pick the narrowest integral type that can hold the value; fall back
+ // to BigInteger for out-of-long-range integers.
+ IntegerValue iv = value.asIntegerValue();
+ if (iv.isInByteRange()) {
+ return iv.toByte();
+ } else if (iv.isInShortRange()) {
+ return iv.toShort();
+ } else if (iv.isInIntRange()) {
+ return iv.toInt();
+ } else if (iv.isInLongRange()) {
+ return iv.toLong();
+ } else {
+ return iv.toBigInteger();
+ }
+ case FLOAT:
+ // All msgpack floats are widened to double.
+ FloatValue fv = value.asFloatValue();
+ return fv.toDouble();
+ case STRING:
+ return value.asStringValue().asString();
+ case BINARY:
+ return value.asBinaryValue().asByteArray();
+ case ARRAY:
+ ArrayValue arrayValue = value.asArrayValue();
+ List list = new ArrayList<>(arrayValue.size());
+ for (Value elem : arrayValue) {
+ list.add(deserialize(elem));
+ }
+ return list;
+ case MAP:
+ MapValue mapValue = value.asMapValue();
+ Map map = new HashMap<>();
+ for (Map.Entry entry : mapValue.entrySet()) {
+ map.put(deserialize(entry.getKey()), deserialize(entry.getValue()));
+ }
+ return map;
+ default:
+ throw new UnsupportedOperationException("Unsupported type " + value.getValueType());
+ }
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/Serializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/Serializer.java
new file mode 100644
index 00000000..0be25b9a
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/Serializer.java
@@ -0,0 +1,12 @@
+package io.ray.streaming.runtime.serialization;
+
+/** Byte-level serialization contract shared by the streaming serializers. */
+public interface Serializer {
+
+ // Presumably type-id tags identifying which serializer produced a payload
+ // (cross-lang / Java / Python) — confirm at the call sites that prepend them.
+ byte CROSS_LANG_TYPE_ID = 0;
+ byte JAVA_TYPE_ID = 1;
+ byte PYTHON_TYPE_ID = 2;
+
+ byte[] serialize(Object object);
+
+ T deserialize(byte[] bytes);
+}
diff --git
a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java new file mode 100644 index 00000000..2e964cb5 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java @@ -0,0 +1,166 @@ +package io.ray.streaming.runtime.transfer; + +import com.google.common.base.Preconditions; +import io.ray.api.BaseActorHandle; +import io.ray.api.id.ActorId; +import io.ray.runtime.actor.LocalModeActorHandle; +import io.ray.runtime.actor.NativeJavaActorHandle; +import io.ray.runtime.actor.NativePyActorHandle; +import io.ray.runtime.functionmanager.FunctionDescriptor; +import io.ray.runtime.functionmanager.JavaFunctionDescriptor; +import io.ray.runtime.functionmanager.PyFunctionDescriptor; +import io.ray.streaming.runtime.worker.JobWorker; +import java.util.ArrayList; +import java.util.List; + +/** Save channel initial parameters needed by DataWriter/DataReader. */ +public class ChannelCreationParametersBuilder { + + public static class Parameter { + + private ActorId actorId; + private FunctionDescriptor asyncFunctionDescriptor; + private FunctionDescriptor syncFunctionDescriptor; + + public void setActorId(ActorId actorId) { + this.actorId = actorId; + } + + public void setAsyncFunctionDescriptor(FunctionDescriptor asyncFunctionDescriptor) { + this.asyncFunctionDescriptor = asyncFunctionDescriptor; + } + + public void setSyncFunctionDescriptor(FunctionDescriptor syncFunctionDescriptor) { + this.syncFunctionDescriptor = syncFunctionDescriptor; + } + + public String toString() { + String language = + asyncFunctionDescriptor instanceof JavaFunctionDescriptor ? 
"Java" : "Python"; + return "Language: " + + language + + " Desc: " + + asyncFunctionDescriptor.toList() + + " " + + syncFunctionDescriptor.toList(); + } + + // Get actor id in bytes, called from jni. + public byte[] getActorIdBytes() { + return actorId.getBytes(); + } + + // Get async function descriptor, called from jni. + public FunctionDescriptor getAsyncFunctionDescriptor() { + return asyncFunctionDescriptor; + } + + // Get sync function descriptor, called from jni. + public FunctionDescriptor getSyncFunctionDescriptor() { + return syncFunctionDescriptor; + } + } + + private List parameters; + + // function descriptors of direct call entry point for Java workers + private static JavaFunctionDescriptor javaReaderAsyncFuncDesc = + new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessage", "([B)V"); + private static JavaFunctionDescriptor javaReaderSyncFuncDesc = + new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessageSync", "([B)[B"); + private static JavaFunctionDescriptor javaWriterAsyncFuncDesc = + new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessage", "([B)V"); + private static JavaFunctionDescriptor javaWriterSyncFuncDesc = + new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessageSync", "([B)[B"); + // function descriptors of direct call entry point for Python workers + private static PyFunctionDescriptor pyReaderAsyncFunctionDesc = + new PyFunctionDescriptor("ray.streaming.runtime.worker", "JobWorker", "on_reader_message"); + private static PyFunctionDescriptor pyReaderSyncFunctionDesc = + new PyFunctionDescriptor( + "ray.streaming.runtime.worker", "JobWorker", "on_reader_message_sync"); + private static PyFunctionDescriptor pyWriterAsyncFunctionDesc = + new PyFunctionDescriptor("ray.streaming.runtime.worker", "JobWorker", "on_writer_message"); + private static PyFunctionDescriptor pyWriterSyncFunctionDesc = + new PyFunctionDescriptor( + "ray.streaming.runtime.worker", "JobWorker", 
"on_writer_message_sync"); + + public ChannelCreationParametersBuilder() {} + + public static void setJavaReaderFunctionDesc( + JavaFunctionDescriptor asyncFunc, JavaFunctionDescriptor syncFunc) { + javaReaderAsyncFuncDesc = asyncFunc; + javaReaderSyncFuncDesc = syncFunc; + } + + public static void setJavaWriterFunctionDesc( + JavaFunctionDescriptor asyncFunc, JavaFunctionDescriptor syncFunc) { + javaWriterAsyncFuncDesc = asyncFunc; + javaWriterSyncFuncDesc = syncFunc; + } + + public ChannelCreationParametersBuilder buildInputQueueParameters( + List queues, List actors) { + return buildParameters( + queues, + actors, + javaWriterAsyncFuncDesc, + javaWriterSyncFuncDesc, + pyWriterAsyncFunctionDesc, + pyWriterSyncFunctionDesc); + } + + public ChannelCreationParametersBuilder buildOutputQueueParameters( + List queues, List actors) { + return buildParameters( + queues, + actors, + javaReaderAsyncFuncDesc, + javaReaderSyncFuncDesc, + pyReaderAsyncFunctionDesc, + pyReaderSyncFunctionDesc); + } + + private ChannelCreationParametersBuilder buildParameters( + List queues, + List actors, + JavaFunctionDescriptor javaAsyncFunctionDesc, + JavaFunctionDescriptor javaSyncFunctionDesc, + PyFunctionDescriptor pyAsyncFunctionDesc, + PyFunctionDescriptor pySyncFunctionDesc) { + parameters = new ArrayList<>(queues.size()); + + for (int i = 0; i < queues.size(); ++i) { + String queue = queues.get(i); + BaseActorHandle actor = actors.get(i); + Parameter parameter = new Parameter(); + Preconditions.checkArgument(actor != null); + parameter.setActorId(actor.getId()); + /// LocalModeRayActor used in single-process mode. 
+ if (actor instanceof NativeJavaActorHandle || actor instanceof LocalModeActorHandle) { + parameter.setAsyncFunctionDescriptor(javaAsyncFunctionDesc); + parameter.setSyncFunctionDescriptor(javaSyncFunctionDesc); + } else if (actor instanceof NativePyActorHandle) { + parameter.setAsyncFunctionDescriptor(pyAsyncFunctionDesc); + parameter.setSyncFunctionDescriptor(pySyncFunctionDesc); + } else { + throw new IllegalArgumentException("Invalid actor type"); + } + parameters.add(parameter); + } + + return this; + } + + // Called from jni + public List getParameters() { + return parameters; + } + + public String toString() { + StringBuilder str = new StringBuilder(); + for (Parameter param : parameters) { + str.append(param.toString()); + } + return str.toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java new file mode 100644 index 00000000..ff3c62fe --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java @@ -0,0 +1,379 @@ +package io.ray.streaming.runtime.transfer; + +import com.google.common.base.Preconditions; +import io.ray.api.BaseActorHandle; +import io.ray.streaming.runtime.config.StreamingWorkerConfig; +import io.ray.streaming.runtime.config.types.TransferChannelType; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo.ChannelCreationStatus; +import io.ray.streaming.runtime.transfer.channel.ChannelUtils; +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; +import io.ray.streaming.runtime.transfer.message.BarrierMessage; +import io.ray.streaming.runtime.transfer.message.ChannelMessage; +import io.ray.streaming.runtime.transfer.message.DataMessage; +import 
io.ray.streaming.runtime.util.Platform; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * DataReader is wrapper of streaming c++ DataReader, which read data from channels of upstream + * workers + */ +public class DataReader { + + private static final Logger LOG = LoggerFactory.getLogger(DataReader.class); + + private long nativeReaderPtr; + // params set by getBundleNative: bundle data address + size + private final ByteBuffer getBundleParams = ByteBuffer.allocateDirect(24); + // We use direct buffer to reduce gc overhead and memory copy. + private final ByteBuffer bundleData = Platform.wrapDirectBuffer(0, 0); + private final ByteBuffer bundleMeta = ByteBuffer.allocateDirect(BundleMeta.LENGTH); + + private final Map queueCreationStatusMap = new HashMap<>(); + private Queue buf = new LinkedList<>(); + + { + getBundleParams.order(ByteOrder.nativeOrder()); + bundleData.order(ByteOrder.nativeOrder()); + bundleMeta.order(ByteOrder.nativeOrder()); + } + + /** + * @param inputChannels input channels ids + * @param fromActors upstream input actors + * @param workerConfig configuration + */ + public DataReader( + List inputChannels, + List fromActors, + Map checkpoints, + StreamingWorkerConfig workerConfig) { + Preconditions.checkArgument(inputChannels.size() > 0); + Preconditions.checkArgument(inputChannels.size() == fromActors.size()); + ChannelCreationParametersBuilder initialParameters = + new ChannelCreationParametersBuilder().buildInputQueueParameters(inputChannels, fromActors); + byte[][] inputChannelsBytes = + inputChannels.stream().map(ChannelId::idStrToBytes).toArray(byte[][]::new); + + // get sequence ID and message ID from OffsetInfo + long[] msgIds = new long[inputChannels.size()]; + for (int i = 0; i < inputChannels.size(); i++) 
{ + String channelId = inputChannels.get(i); + if (!checkpoints.containsKey(channelId)) { + msgIds[i] = 0; + continue; + } + msgIds[i] = checkpoints.get(inputChannels.get(i)).getStreamingMsgId(); + } + long timerInterval = workerConfig.transferConfig.readerTimerIntervalMs(); + TransferChannelType channelType = workerConfig.transferConfig.channelType(); + boolean isMock = false; + if (TransferChannelType.MEMORY_CHANNEL == channelType) { + isMock = true; + } + + // create native reader + List creationStatus = new ArrayList<>(); + this.nativeReaderPtr = + createDataReaderNative( + initialParameters, + inputChannelsBytes, + msgIds, + timerInterval, + creationStatus, + ChannelUtils.toNativeConf(workerConfig), + isMock); + for (int i = 0; i < inputChannels.size(); ++i) { + queueCreationStatusMap.put( + inputChannels.get(i), ChannelCreationStatus.fromInt(creationStatus.get(i))); + } + LOG.info( + "Create DataReader succeed for worker: {}, creation status={}.", + workerConfig.workerInternalConfig.workerName(), + queueCreationStatusMap); + } + + private static native long createDataReaderNative( + ChannelCreationParametersBuilder initialParameters, + byte[][] inputChannels, + long[] msgIds, + long timerInterval, + List creationStatus, + byte[] configBytes, + boolean isMock); + + /** + * Read message from input channels, if timeout, return null. + * + * @param timeoutMillis timeout + * @return message or null + */ + public ChannelMessage read(long timeoutMillis) { + if (buf.isEmpty()) { + getBundle(timeoutMillis); + // if bundle not empty. 
empty message still has data size + seqId + msgId + if (bundleData.position() < bundleData.limit()) { + BundleMeta bundleMeta = new BundleMeta(this.bundleMeta); + String channelID = bundleMeta.getChannelID(); + long timestamp = bundleMeta.getBundleTs(); + // barrier + if (bundleMeta.getBundleType() == DataBundleType.BARRIER) { + buf.offer(getBarrier(bundleData, channelID, timestamp)); + } else if (bundleMeta.getBundleType() == DataBundleType.BUNDLE) { + for (int i = 0; i < bundleMeta.getMessageListSize(); i++) { + buf.offer(getDataMessage(bundleData, channelID, timestamp)); + } + } + } + } + if (buf.isEmpty()) { + return null; + } + return buf.poll(); + } + + public ChannelRecoverInfo getQueueRecoverInfo() { + return new ChannelRecoverInfo(queueCreationStatusMap); + } + + private String getQueueIdString(ByteBuffer buffer) { + byte[] bytes = new byte[ChannelId.ID_LENGTH]; + buffer.get(bytes); + return ChannelId.idBytesToStr(bytes); + } + + private BarrierMessage getBarrier(ByteBuffer bundleData, String channelID, long timestamp) { + ByteBuffer offsetsInfoBytes = ByteBuffer.wrap(getOffsetsInfoNative(nativeReaderPtr)); + offsetsInfoBytes.order(ByteOrder.nativeOrder()); + BarrierOffsetInfo offsetInfo = new BarrierOffsetInfo(offsetsInfoBytes); + DataMessage message = getDataMessage(bundleData, channelID, timestamp); + BarrierItem barrierItem = new BarrierItem(message, offsetInfo); + return new BarrierMessage( + message.getMsgId(), + message.getTimestamp(), + message.getChannelId(), + barrierItem.getData(), + barrierItem.getGlobalBarrierId(), + barrierItem.getBarrierOffsetInfo().getQueueOffsetInfo()); + } + + private DataMessage getDataMessage(ByteBuffer bundleData, String channelID, long timestamp) { + int dataSize = bundleData.getInt(); + // msgId + long msgId = bundleData.getLong(); + // msgType + bundleData.getInt(); + // make `data.capacity() == data.remaining()`, because some code used `capacity()` + // rather than `remaining()` + int position = 
bundleData.position(); + int limit = bundleData.limit(); + bundleData.limit(position + dataSize); + ByteBuffer data = bundleData.slice(); + bundleData.limit(limit); + bundleData.position(position + dataSize); + return new DataMessage(data, timestamp, msgId, channelID); + } + + private void getBundle(long timeoutMillis) { + getBundleNative( + nativeReaderPtr, + timeoutMillis, + Platform.getAddress(getBundleParams), + Platform.getAddress(bundleMeta)); + bundleMeta.rewind(); + long bundleAddress = getBundleParams.getLong(0); + int bundleSize = getBundleParams.getInt(8); + // This has better performance than NewDirectBuffer or set address/capacity in jni. + Platform.wrapDirectBuffer(bundleData, bundleAddress, bundleSize); + } + + /** Stop reader */ + public void stop() { + stopReaderNative(nativeReaderPtr); + } + + /** Close reader to release resource */ + public void close() { + if (nativeReaderPtr == 0) { + return; + } + LOG.info("Closing DataReader."); + closeReaderNative(nativeReaderPtr); + nativeReaderPtr = 0; + LOG.info("Finish closing DataReader."); + } + + private native void getBundleNative( + long nativeReaderPtr, long timeoutMillis, long params, long metaAddress); + + private native byte[] getOffsetsInfoNative(long nativeQueueConsumerPtr); + + private native void stopReaderNative(long nativeReaderPtr); + + private native void closeReaderNative(long nativeReaderPtr); + + enum DataBundleType { + EMPTY(1), + BARRIER(2), + BUNDLE(3); + + int code; + + DataBundleType(int code) { + this.code = code; + } + } + + public enum BarrierType { + GLOBAL_BARRIER(0); + private int code; + + BarrierType(int code) { + this.code = code; + } + } + + class BundleMeta { + + // kMessageBundleHeaderSize + kUniqueIDSize: + // magicNum(4b) + bundleTs(8b) + lastMessageId(8b) + messageListSize(4b) + // + bundleType(4b) + rawBundleSize(4b) + channelID + static final int LENGTH = 4 + 8 + 8 + 4 + 4 + 4 + ChannelId.ID_LENGTH; + private int magicNum; + private long bundleTs; + private long 
lastMessageId; + private int messageListSize; + private DataBundleType bundleType; + private String channelID; + private int rawBundleSize; + + BundleMeta(ByteBuffer buffer) { + // StreamingMessageBundleMeta Deserialization + // magicNum + magicNum = buffer.getInt(); + // messageBundleTs + bundleTs = buffer.getLong(); + // lastOffsetSeqId + lastMessageId = buffer.getLong(); + messageListSize = buffer.getInt(); + int typeInt = buffer.getInt(); + if (DataBundleType.BUNDLE.code == typeInt) { + bundleType = DataBundleType.BUNDLE; + } else if (DataBundleType.BARRIER.code == typeInt) { + bundleType = DataBundleType.BARRIER; + } else { + bundleType = DataBundleType.EMPTY; + } + // rawBundleSize + rawBundleSize = buffer.getInt(); + channelID = getQueueIdString(buffer); + } + + public int getMagicNum() { + return magicNum; + } + + public long getBundleTs() { + return bundleTs; + } + + public long getLastMessageId() { + return lastMessageId; + } + + public int getMessageListSize() { + return messageListSize; + } + + public DataBundleType getBundleType() { + return bundleType; + } + + public String getChannelID() { + return channelID; + } + + public int getRawBundleSize() { + return rawBundleSize; + } + } + + class BarrierOffsetInfo { + + private int queueSize; + private Map queueOffsetInfo; + + public BarrierOffsetInfo(ByteBuffer buffer) { + // deserialization offset + queueSize = buffer.getInt(); + queueOffsetInfo = new HashMap<>(queueSize); + for (int i = 0; i < queueSize; ++i) { + String qid = getQueueIdString(buffer); + long streamingMsgId = buffer.getLong(); + queueOffsetInfo.put(qid, new OffsetInfo(streamingMsgId)); + } + } + + public int getQueueSize() { + return queueSize; + } + + public Map getQueueOffsetInfo() { + return queueOffsetInfo; + } + } + + class BarrierItem { + + BarrierOffsetInfo barrierOffsetInfo; + private long msgId; + private BarrierType barrierType; + private long globalBarrierId; + private ByteBuffer data; + + public BarrierItem(DataMessage 
message, BarrierOffsetInfo barrierOffsetInfo) { + this.barrierOffsetInfo = barrierOffsetInfo; + msgId = message.getMsgId(); + ByteBuffer buffer = message.body(); + // c++ use native order, so use native order here. + buffer.order(ByteOrder.nativeOrder()); + int barrierTypeInt = buffer.getInt(); + globalBarrierId = buffer.getLong(); + // dataSize includes: barrier type(32 bit), globalBarrierId, data + data = buffer.slice(); + data.order(ByteOrder.nativeOrder()); + buffer.position(buffer.limit()); + barrierType = BarrierType.GLOBAL_BARRIER; + } + + public long getBarrierMsgId() { + return msgId; + } + + public BarrierType getBarrierType() { + return barrierType; + } + + public long getGlobalBarrierId() { + return globalBarrierId; + } + + public ByteBuffer getData() { + return data; + } + + public BarrierOffsetInfo getBarrierOffsetInfo() { + return barrierOffsetInfo; + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java new file mode 100644 index 00000000..039b22b7 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java @@ -0,0 +1,181 @@ +package io.ray.streaming.runtime.transfer; + +import com.google.common.base.Preconditions; +import io.ray.api.BaseActorHandle; +import io.ray.streaming.runtime.config.StreamingWorkerConfig; +import io.ray.streaming.runtime.config.types.TransferChannelType; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import io.ray.streaming.runtime.transfer.channel.ChannelUtils; +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; +import io.ray.streaming.runtime.util.Platform; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** 
DataWriter is a wrapper of streaming c++ DataWriter, which sends data to downstream workers */ +public class DataWriter { + + private static final Logger LOG = LoggerFactory.getLogger(DataWriter.class); + + private long nativeWriterPtr; + private ByteBuffer buffer = ByteBuffer.allocateDirect(0); + private long bufferAddress; + private List outputChannels; + + { + ensureBuffer(0); + } + + /** + * @param outputChannels output channels ids + * @param toActors downstream output actors + * @param workerConfig configuration + * @param checkpoints offset of each channels + */ + public DataWriter( + List outputChannels, + List toActors, + Map checkpoints, + StreamingWorkerConfig workerConfig) { + Preconditions.checkArgument(!outputChannels.isEmpty()); + Preconditions.checkArgument(outputChannels.size() == toActors.size()); + this.outputChannels = outputChannels; + + ChannelCreationParametersBuilder initialParameters = + new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors); + + byte[][] outputChannelsBytes = + outputChannels.stream().map(ChannelId::idStrToBytes).toArray(byte[][]::new); + long channelSize = workerConfig.transferConfig.channelSize(); + + // load message id from checkpoints + long[] msgIds = new long[outputChannels.size()]; + for (int i = 0; i < outputChannels.size(); i++) { + String channelId = outputChannels.get(i); + if (!checkpoints.containsKey(channelId)) { + msgIds[i] = 0; + continue; + } + msgIds[i] = checkpoints.get(channelId).getStreamingMsgId(); + } + TransferChannelType channelType = workerConfig.transferConfig.channelType(); + boolean isMock = false; + if (TransferChannelType.MEMORY_CHANNEL == channelType) { + isMock = true; + } + this.nativeWriterPtr = + createWriterNative( + initialParameters, + outputChannelsBytes, + msgIds, + channelSize, + ChannelUtils.toNativeConf(workerConfig), + isMock); + LOG.info( + "Create DataWriter succeed for worker: {}.", + workerConfig.workerInternalConfig.workerName()); + } + 
+ private static native long createWriterNative( + ChannelCreationParametersBuilder initialParameters, + byte[][] outputQueueIds, + long[] msgIds, + long channelSize, + byte[] confBytes, + boolean isMock); + + /** + * Write msg into the specified channel + * + * @param id channel id + * @param item message item data section is specified by [position, limit). + */ + public void write(ChannelId id, ByteBuffer item) { + int size = item.remaining(); + ensureBuffer(size); + buffer.clear(); + buffer.put(item); + writeMessageNative(nativeWriterPtr, id.getNativeIdPtr(), bufferAddress, size); + } + + /** + * Write msg into the specified channels + * + * @param ids channel ids + * @param item message item data section is specified by [position, limit). item doesn't have to + * be a direct buffer. + */ + public void write(Set ids, ByteBuffer item) { + int size = item.remaining(); + ensureBuffer(size); + for (ChannelId id : ids) { + buffer.clear(); + buffer.put(item.duplicate()); + writeMessageNative(nativeWriterPtr, id.getNativeIdPtr(), bufferAddress, size); + } + } + + private void ensureBuffer(int size) { + if (buffer.capacity() < size) { + buffer = ByteBuffer.allocateDirect(size); + buffer.order(ByteOrder.nativeOrder()); + bufferAddress = Platform.getAddress(buffer); + } + } + + public Map getOutputCheckpoints() { + long[] msgId = getOutputMsgIdNative(nativeWriterPtr); + Map res = new HashMap<>(outputChannels.size()); + for (int i = 0; i < outputChannels.size(); ++i) { + res.put(outputChannels.get(i), new OffsetInfo(msgId[i])); + } + LOG.info("got output points, {}.", res); + return res; + } + + public void broadcastBarrier(long checkpointId, ByteBuffer attach) { + LOG.info("Broadcast barrier, cpId={}.", checkpointId); + Preconditions.checkArgument(attach.order() == ByteOrder.nativeOrder()); + broadcastBarrierNative(nativeWriterPtr, checkpointId, attach.array()); + } + + public void clearCheckpoint(long checkpointId) { + LOG.info("Producer clear checkpoint, 
checkpointId={}.", checkpointId); + clearCheckpointNative(nativeWriterPtr, checkpointId); + } + + /** stop writer */ + public void stop() { + stopWriterNative(nativeWriterPtr); + } + + /** close writer to release resources */ + public void close() { + if (nativeWriterPtr == 0) { + return; + } + LOG.info("Closing data writer."); + closeWriterNative(nativeWriterPtr); + nativeWriterPtr = 0; + LOG.info("Finish closing data writer."); + } + + private native long writeMessageNative( + long nativeQueueProducerPtr, long nativeIdPtr, long address, int size); + + private native void stopWriterNative(long nativeQueueProducerPtr); + + private native void closeWriterNative(long nativeQueueProducerPtr); + + private native long[] getOutputMsgIdNative(long nativeQueueProducerPtr); + + private native void broadcastBarrierNative( + long nativeQueueProducerPtr, long checkpointId, byte[] data); + + private native void clearCheckpointNative(long nativeQueueProducerPtr, long checkpointId); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/TransferHandler.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/TransferHandler.java new file mode 100644 index 00000000..816273d8 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/TransferHandler.java @@ -0,0 +1,52 @@ +package io.ray.streaming.runtime.transfer; + +import io.ray.runtime.util.BinaryFileUtil; +import io.ray.runtime.util.JniUtils; + +/** + * TransferHandler is used for handle direct call based data transfer between workers. + * TransferHandler is used by streaming queue for data transfer. 
+ */ +public class TransferHandler { + + static { + JniUtils.loadLibrary(BinaryFileUtil.CORE_WORKER_JAVA_LIBRARY, true); + JniUtils.loadLibrary("streaming_java"); + } + + private long writerClientNative; + private long readerClientNative; + + public TransferHandler() { + writerClientNative = createWriterClientNative(); + readerClientNative = createReaderClientNative(); + } + + public void onWriterMessage(byte[] buffer) { + handleWriterMessageNative(writerClientNative, buffer); + } + + public byte[] onWriterMessageSync(byte[] buffer) { + return handleWriterMessageSyncNative(writerClientNative, buffer); + } + + public void onReaderMessage(byte[] buffer) { + handleReaderMessageNative(readerClientNative, buffer); + } + + public byte[] onReaderMessageSync(byte[] buffer) { + return handleReaderMessageSyncNative(readerClientNative, buffer); + } + + private native long createWriterClientNative(); + + private native long createReaderClientNative(); + + private native void handleWriterMessageNative(long handler, byte[] buffer); + + private native byte[] handleWriterMessageSyncNative(long handler, byte[] buffer); + + private native void handleReaderMessageNative(long handler, byte[] buffer); + + private native byte[] handleReaderMessageSyncNative(long handler, byte[] buffer); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelId.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelId.java new file mode 100644 index 00000000..731031d6 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelId.java @@ -0,0 +1,178 @@ +package io.ray.streaming.runtime.transfer.channel; + +import com.google.common.base.FinalizablePhantomReference; +import com.google.common.base.FinalizableReferenceQueue; +import com.google.common.base.Preconditions; +import com.google.common.collect.Sets; +import 
com.google.common.io.BaseEncoding; +import io.ray.api.id.ObjectId; +import java.lang.ref.Reference; +import java.nio.ByteBuffer; +import java.util.Random; +import java.util.Set; +import sun.nio.ch.DirectBuffer; + +/** + * ChannelID is used to identify a transfer channel between a upstream worker and downstream worker. + */ +public class ChannelId { + + public static final int ID_LENGTH = ObjectId.LENGTH; + private static final FinalizableReferenceQueue REFERENCE_QUEUE = new FinalizableReferenceQueue(); + // This ensures that the FinalizablePhantomReference itself is not garbage-collected. + private static final Set> references = Sets.newConcurrentHashSet(); + + private final byte[] bytes; + private final String strId; + private final ByteBuffer buffer; + private final long address; + private final long nativeIdPtr; + + private ChannelId(String strId, byte[] idBytes) { + this.strId = strId; + this.bytes = idBytes; + ByteBuffer directBuffer = ByteBuffer.allocateDirect(ID_LENGTH); + directBuffer.put(bytes); + directBuffer.rewind(); + this.buffer = directBuffer; + this.address = ((DirectBuffer) (buffer)).address(); + long nativeIdPtr = 0; + nativeIdPtr = createNativeId(address); + this.nativeIdPtr = nativeIdPtr; + } + + private static native long createNativeId(long idAddress); + + private static native void destroyNativeId(long nativeIdPtr); + + /** @param id hex string representation of channel id */ + public static ChannelId from(String id) { + return from(id, ChannelId.idStrToBytes(id)); + } + + /** @param idBytes bytes representation of channel id */ + public static ChannelId from(byte[] idBytes) { + return from(idBytesToStr(idBytes), idBytes); + } + + private static ChannelId from(String strID, byte[] idBytes) { + ChannelId id = new ChannelId(strID, idBytes); + long nativeIdPtr = id.nativeIdPtr; + if (nativeIdPtr != 0) { + Reference reference = + new FinalizablePhantomReference(id, REFERENCE_QUEUE) { + @Override + public void finalizeReferent() { + 
destroyNativeId(nativeIdPtr); + references.remove(this); + } + }; + references.add(reference); + } + return id; + } + + /** Returns a random channel id string */ + public static String genRandomIdStr() { + StringBuilder sb = new StringBuilder(); + Random random = new Random(); + for (int i = 0; i < ChannelId.ID_LENGTH * 2; ++i) { + sb.append((char) (random.nextInt(6) + 'A')); + } + return sb.toString(); + } + + /** + * Generate channel name, which will be {@link ChannelId#ID_LENGTH} character + * + * @param fromTaskId upstream task id + * @param toTaskId downstream task id + * @return channel name + */ + public static String genIdStr(int fromTaskId, int toTaskId, long ts) { + /* + | Head | Timestamp | Empty | From | To | padding | + | 8 bytes | 4bytes | 4bytes| 2bytes| 2bytes | | + */ + Preconditions.checkArgument( + fromTaskId < Short.MAX_VALUE, + "fromTaskId %s is larger than %s", + fromTaskId, + Short.MAX_VALUE); + Preconditions.checkArgument( + toTaskId < Short.MAX_VALUE, "toTaskId %s is larger than %s", fromTaskId, Short.MAX_VALUE); + byte[] channelName = new byte[ID_LENGTH]; + + for (int i = 11; i >= 8; i--) { + channelName[i] = (byte) (ts & 0xff); + ts >>= 8; + } + + channelName[16] = (byte) ((fromTaskId & 0xffff) >> 8); + channelName[17] = (byte) (fromTaskId & 0xff); + channelName[18] = (byte) ((toTaskId & 0xffff) >> 8); + channelName[19] = (byte) (toTaskId & 0xff); + + return ChannelId.idBytesToStr(channelName); + } + + /** + * @param id hex string representation of channel id + * @return bytes representation of channel id + */ + public static byte[] idStrToBytes(String id) { + byte[] idBytes = BaseEncoding.base16().decode(id.toUpperCase()); + assert idBytes.length == ChannelId.ID_LENGTH; + return idBytes; + } + + /** + * @param id bytes representation of channel id + * @return hex string representation of channel id + */ + public static String idBytesToStr(byte[] id) { + assert id.length == ChannelId.ID_LENGTH; + return 
BaseEncoding.base16().encode(id).toLowerCase(); + } + + public byte[] getBytes() { + return bytes; + } + + public ByteBuffer getBuffer() { + return buffer; + } + + public long getAddress() { + return address; + } + + public long getNativeIdPtr() { + if (nativeIdPtr == 0) { + throw new IllegalStateException("native ID not available"); + } + return nativeIdPtr; + } + + @Override + public String toString() { + return strId; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ChannelId that = (ChannelId) o; + return strId.equals(that.strId); + } + + @Override + public int hashCode() { + return strId.hashCode(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelRecoverInfo.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelRecoverInfo.java new file mode 100644 index 00000000..fd785cdb --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelRecoverInfo.java @@ -0,0 +1,57 @@ +package io.ray.streaming.runtime.transfer.channel; + +import com.google.common.base.MoreObjects; +import java.io.Serializable; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ChannelRecoverInfo implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(ChannelRecoverInfo.class); + public Map queueCreationStatusMap; + + public ChannelRecoverInfo(Map queueCreationStatusMap) { + this.queueCreationStatusMap = queueCreationStatusMap; + } + + public Set getDataLostQueues() { + Set dataLostQueues = new HashSet<>(); + queueCreationStatusMap.forEach( + (q, status) -> { + if (status.equals(ChannelCreationStatus.DataLost)) { + dataLostQueues.add(q); + } + }); + return dataLostQueues; + } + + @Override 
+ public String toString() { + return MoreObjects.toStringHelper(this).add("dataLostQueues", getDataLostQueues()).toString(); + } + + public enum ChannelCreationStatus { + FreshStarted(0), + PullOk(1), + Timeout(2), + DataLost(3); + + private int id; + + ChannelCreationStatus(int id) { + this.id = id; + } + + public static ChannelCreationStatus fromInt(int id) { + for (ChannelCreationStatus status : ChannelCreationStatus.values()) { + if (status.id == id) { + return status; + } + } + return null; + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelUtils.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelUtils.java new file mode 100644 index 00000000..d2954352 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelUtils.java @@ -0,0 +1,68 @@ +package io.ray.streaming.runtime.transfer.channel; + +import io.ray.streaming.runtime.config.StreamingWorkerConfig; +import io.ray.streaming.runtime.generated.Streaming; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ChannelUtils { + + private static final Logger LOGGER = LoggerFactory.getLogger(ChannelUtils.class); + + public static byte[] toNativeConf(StreamingWorkerConfig workerConfig) { + Streaming.StreamingConfig.Builder builder = Streaming.StreamingConfig.newBuilder(); + + // job name + String jobName = workerConfig.commonConfig.jobName(); + if (!StringUtils.isEmpty(jobName)) { + builder.setJobName(workerConfig.commonConfig.jobName()); + } + + // worker name + String workerName = workerConfig.workerInternalConfig.workerName(); + if (!StringUtils.isEmpty(workerName)) { + builder.setWorkerName(workerName); + } + + // operator name + String operatorName = workerConfig.workerInternalConfig.workerOperatorName(); + if (!StringUtils.isEmpty(operatorName)) { + 
builder.setOpName(operatorName); + } + + // ring buffer capacity + int ringBufferCapacity = workerConfig.transferConfig.ringBufferCapacity(); + if (ringBufferCapacity != -1) { + builder.setRingBufferCapacity(ringBufferCapacity); + } + + // empty message interval + int emptyMsgInterval = workerConfig.transferConfig.emptyMsgInterval(); + if (emptyMsgInterval != -1) { + builder.setEmptyMessageInterval(emptyMsgInterval); + } + + // flow control type + int flowControlType = workerConfig.transferConfig.flowControlType(); + if (flowControlType != -1) { + builder.setFlowControlType(Streaming.FlowControlType.forNumber(flowControlType)); + } + + // writer consumed step + int writerConsumedStep = workerConfig.transferConfig.writerConsumedStep(); + if (writerConsumedStep != -1) { + builder.setWriterConsumedStep(writerConsumedStep); + } + + // reader consumed step + int readerConsumedStep = workerConfig.transferConfig.readerConsumedStep(); + if (readerConsumedStep != -1) { + builder.setReaderConsumedStep(readerConsumedStep); + } + + Streaming.StreamingConfig streamingConf = builder.build(); + LOGGER.info("Streaming native conf {}", streamingConf.toString()); + return streamingConf.toByteArray(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/OffsetInfo.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/OffsetInfo.java new file mode 100644 index 00000000..3cdfbf81 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/OffsetInfo.java @@ -0,0 +1,27 @@ +package io.ray.streaming.runtime.transfer.channel; + +import com.google.common.base.MoreObjects; +import java.io.Serializable; + +/** This data structure contains offset used by streaming queue. 
*/ +public class OffsetInfo implements Serializable { + + private long streamingMsgId; + + public OffsetInfo(long streamingMsgId) { + this.streamingMsgId = streamingMsgId; + } + + public long getStreamingMsgId() { + return streamingMsgId; + } + + public void setStreamingMsgId(long streamingMsgId) { + this.streamingMsgId = streamingMsgId; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("streamingMsgId", streamingMsgId).toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/exception/ChannelInterruptException.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/exception/ChannelInterruptException.java new file mode 100644 index 00000000..f4d909ce --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/exception/ChannelInterruptException.java @@ -0,0 +1,22 @@ +package io.ray.streaming.runtime.transfer.exception; + +import io.ray.streaming.runtime.transfer.DataReader; +import io.ray.streaming.runtime.transfer.DataWriter; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import java.nio.ByteBuffer; + +/** + * When {@link DataReader#stop()} or {@link DataWriter#stop()} is called, this exception might be + * thrown in {@link DataReader#read(long)} and {@link DataWriter#write(ChannelId, ByteBuffer)}, + * which means the read/write operation has failed. 
+ */ +public class ChannelInterruptException extends RuntimeException { + + public ChannelInterruptException() { + super(); + } + + public ChannelInterruptException(String message) { + super(message); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/BarrierMessage.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/BarrierMessage.java new file mode 100644 index 00000000..7ea8d60f --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/BarrierMessage.java @@ -0,0 +1,37 @@ +package io.ray.streaming.runtime.transfer.message; + +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; +import java.nio.ByteBuffer; +import java.util.Map; + +public class BarrierMessage extends ChannelMessage { + + private final ByteBuffer data; + private final long checkpointId; + private final Map inputOffsets; + + public BarrierMessage( + long msgId, + long timestamp, + String channelId, + ByteBuffer data, + long checkpointId, + Map inputOffsets) { + super(msgId, timestamp, channelId); + this.data = data; + this.checkpointId = checkpointId; + this.inputOffsets = inputOffsets; + } + + public ByteBuffer getData() { + return data; + } + + public long getCheckpointId() { + return checkpointId; + } + + public Map getInputOffsets() { + return inputOffsets; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/ChannelMessage.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/ChannelMessage.java new file mode 100644 index 00000000..6bfa4dca --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/ChannelMessage.java @@ -0,0 +1,26 @@ +package io.ray.streaming.runtime.transfer.message; + +public class ChannelMessage { + + private final long msgId; + private final long timestamp; + private 
final String channelId; + + public ChannelMessage(long msgId, long timestamp, String channelId) { + this.msgId = msgId; + this.timestamp = timestamp; + this.channelId = channelId; + } + + public long getMsgId() { + return msgId; + } + + public long getTimestamp() { + return timestamp; + } + + public String getChannelId() { + return channelId; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/DataMessage.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/DataMessage.java new file mode 100644 index 00000000..de873ae5 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/DataMessage.java @@ -0,0 +1,18 @@ +package io.ray.streaming.runtime.transfer.message; + +import java.nio.ByteBuffer; + +/** DataMessage represents data between upstream and downstream operators. */ +public class DataMessage extends ChannelMessage { + + private final ByteBuffer body; + + public DataMessage(ByteBuffer body, long timestamp, long msgId, String channelId) { + super(msgId, timestamp, channelId); + this.body = body; + } + + public ByteBuffer body() { + return body; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CheckpointStateUtil.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CheckpointStateUtil.java new file mode 100644 index 00000000..b1da3f9c --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CheckpointStateUtil.java @@ -0,0 +1,56 @@ +package io.ray.streaming.runtime.util; + +import io.ray.streaming.runtime.context.ContextBackend; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Handle exception for checkpoint state */ +public class CheckpointStateUtil { + + private static final Logger LOG = LoggerFactory.getLogger(CheckpointStateUtil.class); + + /** + * DO NOT ALLOW GET EXCEPTION 
WHEN LOADING CHECKPOINT + * + * @param checkpointState state backend + * @param cpKey checkpoint key + */ + public static byte[] get(ContextBackend checkpointState, String cpKey) { + byte[] val; + try { + val = checkpointState.get(cpKey); + } catch (Exception e) { + throw new CheckpointStateRuntimeException( + String.format("Failed to get %s from state backend.", cpKey), e); + } + return val; + } + + /** + * ALLOW PUT EXCEPTION WHEN SAVING CHECKPOINT + * + * @param checkpointState state backend + * @param key checkpoint key + * @param val checkpoint value + */ + public static void put(ContextBackend checkpointState, String key, byte[] val) { + try { + checkpointState.put(key, val); + } catch (Exception e) { + LOG.error("Failed to put key {} to state backend.", key, e); + } + } + + public static class CheckpointStateRuntimeException extends RuntimeException { + + public CheckpointStateRuntimeException() {} + + public CheckpointStateRuntimeException(String message) { + super(message); + } + + public CheckpointStateRuntimeException(String message, Throwable cause) { + super(message, cause); + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CommonUtils.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CommonUtils.java new file mode 100644 index 00000000..3ce8c1fe --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CommonUtils.java @@ -0,0 +1,12 @@ +package io.ray.streaming.runtime.util; + +import java.util.Map; + +/** Common tools. 
*/ +public class CommonUtils { + + public static Map strMapToObjectMap(Map srcMap) { + Map destMap = (Map) srcMap; + return destMap; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/EnvUtil.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/EnvUtil.java new file mode 100644 index 00000000..29ac29f4 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/EnvUtil.java @@ -0,0 +1,57 @@ +package io.ray.streaming.runtime.util; + +import io.ray.runtime.util.BinaryFileUtil; +import io.ray.runtime.util.JniUtils; +import java.lang.management.ManagementFactory; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class EnvUtil { + + private static final Logger LOG = LoggerFactory.getLogger(EnvUtil.class); + + public static String getJvmPid() { + return ManagementFactory.getRuntimeMXBean().getName().split("@")[0]; + } + + public static String getHostName() { + String hostname = ""; + try { + hostname = InetAddress.getLocalHost().getHostName(); + } catch (UnknownHostException e) { + LOG.error("Error occurs while fetching local host.", e); + } + return hostname; + } + + public static void loadNativeLibraries() { + JniUtils.loadLibrary(BinaryFileUtil.CORE_WORKER_JAVA_LIBRARY, true); + JniUtils.loadLibrary("streaming_java"); + } + + /** + * Execute an external command. + * + * @return Whether the command succeeded. 
+ */ + public static boolean executeCommand(List command, int waitTimeoutSeconds) { + try { + ProcessBuilder processBuilder = + new ProcessBuilder(command) + .redirectOutput(ProcessBuilder.Redirect.INHERIT) + .redirectError(ProcessBuilder.Redirect.INHERIT); + Process process = processBuilder.start(); + boolean exit = process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS); + if (!exit) { + process.destroyForcibly(); + } + return process.exitValue() == 0; + } catch (Exception e) { + throw new RuntimeException("Error executing command " + String.join(" ", command), e); + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Platform.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Platform.java new file mode 100644 index 00000000..324e1ab9 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Platform.java @@ -0,0 +1,87 @@ +package io.ray.streaming.runtime.util; + +import com.google.common.base.Preconditions; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import sun.misc.Unsafe; +import sun.nio.ch.DirectBuffer; + +/** Based on org.apache.spark.unsafe.Platform */ +public final class Platform { + + public static final Unsafe UNSAFE; + + static { + Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + throw new UnsupportedOperationException("Unsafe is not supported in this platform."); + } + UNSAFE = unsafe; + } + + // Access fields and constructors once and store them, for performance: + private static final Constructor DBB_CONSTRUCTOR; + private static final long BUFFER_ADDRESS_FIELD_OFFSET; + private static final long BUFFER_CAPACITY_FIELD_OFFSET; + + static { + try { + Class 
cls = Class.forName("java.nio.DirectByteBuffer"); + Constructor constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE); + constructor.setAccessible(true); + DBB_CONSTRUCTOR = constructor; + Field addressField = Buffer.class.getDeclaredField("address"); + BUFFER_ADDRESS_FIELD_OFFSET = UNSAFE.objectFieldOffset(addressField); + Preconditions.checkArgument(BUFFER_ADDRESS_FIELD_OFFSET != 0); + Field capacityField = Buffer.class.getDeclaredField("capacity"); + BUFFER_CAPACITY_FIELD_OFFSET = UNSAFE.objectFieldOffset(capacityField); + Preconditions.checkArgument(BUFFER_CAPACITY_FIELD_OFFSET != 0); + } catch (ClassNotFoundException | NoSuchMethodException | NoSuchFieldException e) { + throw new IllegalStateException(e); + } + } + + private static final ThreadLocal localEmptyBuffer = + ThreadLocal.withInitial( + () -> { + try { + return (ByteBuffer) DBB_CONSTRUCTOR.newInstance(0, 0); + } catch (InstantiationException + | IllegalAccessException + | InvocationTargetException e) { + UNSAFE.throwException(e); + } + throw new IllegalStateException("unreachable"); + }); + + /** Wrap a buffer [address, address + size) as a DirectByteBuffer. */ + public static ByteBuffer wrapDirectBuffer(long address, int size) { + ByteBuffer buffer = localEmptyBuffer.get().duplicate(); + UNSAFE.putLong(buffer, BUFFER_ADDRESS_FIELD_OFFSET, address); + UNSAFE.putInt(buffer, BUFFER_CAPACITY_FIELD_OFFSET, size); + buffer.clear(); + return buffer; + } + + /** Wrap a buffer [address, address + size) into provided buffer. 
*/ + public static void wrapDirectBuffer(ByteBuffer buffer, long address, int size) { + UNSAFE.putLong(buffer, BUFFER_ADDRESS_FIELD_OFFSET, address); + UNSAFE.putInt(buffer, BUFFER_CAPACITY_FIELD_OFFSET, size); + buffer.clear(); + } + + /** + * @param buffer a DirectBuffer backed by off-heap memory + * @return address of off-heap memory + */ + public static long getAddress(ByteBuffer buffer) { + return ((DirectBuffer) buffer).address(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/RayUtils.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/RayUtils.java new file mode 100644 index 00000000..b3243d69 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/RayUtils.java @@ -0,0 +1,65 @@ +package io.ray.streaming.runtime.util; + +import io.ray.api.Ray; +import io.ray.api.id.UniqueId; +import io.ray.api.runtimecontext.NodeInfo; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** RayUtils is the utility class to access ray runtime api. 
*/ +public class RayUtils { + + /** + * Get all node info from GCS + * + * @return node info list + */ + public static List getAllNodeInfo() { + if (Ray.getRuntimeContext().isSingleProcess()) { + // only for single process(for unit test) + return mockContainerResources(); + } + return Ray.getRuntimeContext().getAllNodeInfo(); + } + + /** + * Get all alive node info map + * + * @return node info map, key is unique node id , value is node info + */ + public static Map getAliveNodeInfoMap() { + return getAllNodeInfo().stream() + .filter(nodeInfo -> nodeInfo.isAlive) + .collect(Collectors.toMap(nodeInfo -> nodeInfo.nodeId, nodeInfo -> nodeInfo)); + } + + private static List mockContainerResources() { + List nodeInfos = new LinkedList<>(); + + for (int i = 1; i <= 5; i++) { + Map resources = new HashMap<>(); + resources.put("CPU", (double) i); + resources.put("MEM", 16.0); + + byte[] nodeIdBytes = new byte[UniqueId.LENGTH]; + for (int byteIndex = 0; byteIndex < UniqueId.LENGTH; ++byteIndex) { + nodeIdBytes[byteIndex] = String.valueOf(i).getBytes()[0]; + } + NodeInfo nodeInfo = + new NodeInfo( + new UniqueId(nodeIdBytes), + "localhost" + i, + "localhost" + i, + -1, + "", + "", + true, + resources); + nodeInfos.add(nodeInfo); + } + return nodeInfos; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ReflectionUtils.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ReflectionUtils.java new file mode 100644 index 00000000..13a75f8e --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ReflectionUtils.java @@ -0,0 +1,86 @@ +package io.ray.streaming.runtime.util; + +import com.google.common.base.Preconditions; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; + +@SuppressWarnings("UnstableApiUsage") +public class ReflectionUtils 
{ + + public static Method findMethod(Class cls, String methodName) { + List methods = findMethods(cls, methodName); + Preconditions.checkArgument(methods.size() == 1); + return methods.get(0); + } + + /** + * For covariant return type, return the most specific method. + * + * @return all methods named by {@code methodName}, + */ + public static List findMethods(Class cls, String methodName) { + List> classes = new ArrayList<>(); + Class clazz = cls; + while (clazz != null) { + classes.add(clazz); + clazz = clazz.getSuperclass(); + } + classes.addAll(getAllInterfaces(cls)); + if (classes.indexOf(Object.class) == -1) { + classes.add(Object.class); + } + + LinkedHashMap>, Method> methods = new LinkedHashMap<>(); + for (Class superClass : classes) { + for (Method m : superClass.getDeclaredMethods()) { + if (m.getName().equals(methodName)) { + List> params = Arrays.asList(m.getParameterTypes()); + Method method = methods.get(params); + if (method == null) { + methods.put(params, m); + } else { + // for covariant return type, use the most specific method + if (method.getReturnType().isAssignableFrom(m.getReturnType())) { + methods.put(params, m); + } + } + } + } + } + return new ArrayList<>(methods.values()); + } + + /** + * Gets a List of all interfaces implemented by the given class and its superclasses. + * + *

The order is determined by looking through each interface in turn as declared in the source + * file and following its hierarchy up. + */ + public static List> getAllInterfaces(Class cls) { + if (cls == null) { + return null; + } + + LinkedHashSet> interfacesFound = new LinkedHashSet<>(); + getAllInterfaces(cls, interfacesFound); + return new ArrayList<>(interfacesFound); + } + + private static void getAllInterfaces(Class cls, LinkedHashSet> interfacesFound) { + while (cls != null) { + Class[] interfaces = cls.getInterfaces(); + for (Class anInterface : interfaces) { + if (!interfacesFound.contains(anInterface)) { + interfacesFound.add(anInterface); + getAllInterfaces(anInterface, interfacesFound); + } + } + + cls = cls.getSuperclass(); + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ResourceUtil.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ResourceUtil.java new file mode 100644 index 00000000..b00b6ee9 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ResourceUtil.java @@ -0,0 +1,202 @@ +package io.ray.streaming.runtime.util; + +import com.sun.management.OperatingSystemMXBean; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.core.resource.ContainerId; +import java.io.BufferedInputStream; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Resource Utility collects current OS and JVM resource usage information */ +public class ResourceUtil { + + public static final Logger LOG = LoggerFactory.getLogger(ResourceUtil.class); + + /** + * Refer to: + * 
https://docs.oracle.com/javase/8/docs/jre/api/management/extension/com/sun/management/OperatingSystemMXBean.html + */ + private static OperatingSystemMXBean osmxb = + (OperatingSystemMXBean) ManagementFactory.getOperatingSystemMXBean(); + + /** Log current jvm process's memory detail */ + public static void logProcessMemoryDetail() { + int mb = 1024 * 1024; + + // Getting the runtime reference from system + Runtime runtime = Runtime.getRuntime(); + + StringBuilder sb = new StringBuilder(32); + + sb.append("used memory: ") + .append((runtime.totalMemory() - runtime.freeMemory()) / mb) + .append(", free memory: ") + .append(runtime.freeMemory() / mb) + .append(", total memory: ") + .append(runtime.totalMemory() / mb) + .append(", max memory: ") + .append(runtime.maxMemory() / mb); + + if (LOG.isInfoEnabled()) { + LOG.info(sb.toString()); + } + } + + /** + * @return jvm heap usage ratio. note that one of the survivor space is not include in total + * memory while calculating this ratio. + */ + public static double getJvmHeapUsageRatio() { + Runtime runtime = Runtime.getRuntime(); + return (runtime.totalMemory() - runtime.freeMemory()) * 1.0 / runtime.maxMemory(); + } + + /** + * @return jvm heap usage(in bytes). note that this value doesn't include one of the survivor + * space. + */ + public static long getJvmHeapUsageInBytes() { + Runtime runtime = Runtime.getRuntime(); + return runtime.totalMemory() - runtime.freeMemory(); + } + + /** Returns the total amount of physical memory in bytes. */ + public static long getSystemTotalMemory() { + return osmxb.getTotalPhysicalMemorySize(); + } + + /** Returns the used system physical memory in bytes */ + public static long getSystemMemoryUsage() { + long totalMemory = osmxb.getTotalPhysicalMemorySize(); + long freeMemory = osmxb.getFreePhysicalMemorySize(); + return totalMemory - freeMemory; + } + + /** Returns the ratio of used system physical memory. 
This value is a double in the [0.0,1.0] interval. */ + public static double getSystemMemoryUsageRatio() { + double totalMemory = osmxb.getTotalPhysicalMemorySize(); + double freeMemory = osmxb.getFreePhysicalMemorySize(); + double ratio = freeMemory / totalMemory; + return 1 - ratio; + } + + /** Returns the cpu load for current jvm process. This value is a double in the [0.0,1.0] interval. */ + public static double getProcessCpuUsage() { + return osmxb.getProcessCpuLoad(); + } + + /** + * @return the system cpu usage. This value is a double in the [0.0,1.0] interval. We will try to use `vsar` + * to get cpu usage by default, and use MXBean if any exception is raised. + */ + public static double getSystemCpuUsage() { + double cpuUsage = 0.0; + try { + cpuUsage = getSystemCpuUtilByVsar(); + } catch (Exception e) { + cpuUsage = getSystemCpuUtilByMXBean(); + } + return cpuUsage; + } + + /** + * @return the "recent cpu usage" for the whole system. This value is a double in the [0.0,1.0] + * interval. A value of 0.0 means that all CPUs were idle during the recent period of time + * observed, while a value of 1.0 means that all CPUs were actively running 100% of the time + * during the recent period being observed + */ + public static double getSystemCpuUtilByMXBean() { + return osmxb.getSystemCpuLoad(); + } + + /** Get system cpu util by vsar */ + public static double getSystemCpuUtilByVsar() throws Exception { + double cpuUsageFromVsar = 0.0; + String[] vsarCpuCommand = {"/bin/sh", "-c", "vsar --check --cpu -s util"}; + try { + Process proc = Runtime.getRuntime().exec(vsarCpuCommand); + BufferedInputStream bis = new BufferedInputStream(proc.getInputStream()); + BufferedReader br = new BufferedReader(new InputStreamReader(bis)); + String line; + List processPidList = new ArrayList<>(); + while ((line = br.readLine()) != null) { + processPidList.add(line); + } + if (!processPidList.isEmpty()) { + String[] split = processPidList.get(0).split("="); + cpuUsageFromVsar = Double.parseDouble(split[1]) / 
100.0D; + } else { + throw new IOException("Vsar check cpu usage failed, maybe vsar is not installed."); + } + } catch (Exception e) { + LOG.warn("Failed to get cpu usage by vsar.", e); + throw e; + } + return cpuUsageFromVsar; + } + + /** Returns the system load average for the last minute */ + public static double getSystemLoadAverage() { + return osmxb.getSystemLoadAverage(); + } + + /** Returns system cpu cores num */ + public static int getCpuCores() { + return osmxb.getAvailableProcessors(); + } + + /** + * Get containers by hostname of address + * + * @param containers container list + * @param containerHosts container hostname or address set + * @return matched containers + */ + public static List getContainersByHostname( + List containers, Collection containerHosts) { + + return containers.stream() + .filter( + container -> + containerHosts.contains(container.getHostname()) + || containerHosts.contains(container.getAddress())) + .collect(Collectors.toList()); + } + + /** + * Get container by hostname + * + * @param hostName container hostname + * @return container + */ + public static Optional getContainerByHostname( + List containers, String hostName) { + return containers.stream() + .filter( + container -> + container.getHostname().equals(hostName) || container.getAddress().equals(hostName)) + .findFirst(); + } + + /** + * Get container by id + * + * @param containerID container id + * @return container + */ + public static Optional getContainerById( + List containers, ContainerId containerID) { + return containers.stream() + .filter(container -> container.getId().equals(containerID)) + .findFirst(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Serializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Serializer.java new file mode 100644 index 00000000..df435e5a --- /dev/null +++ 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Serializer.java @@ -0,0 +1,14 @@ +package io.ray.streaming.runtime.util; + +import io.ray.runtime.serializer.FstSerializer; + +public class Serializer { + + public static byte[] encode(Object obj) { + return FstSerializer.encode(obj); + } + + public static T decode(byte[] bytes) { + return FstSerializer.decode(bytes); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java new file mode 100644 index 00000000..15200c65 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java @@ -0,0 +1,336 @@ +package io.ray.streaming.runtime.worker; + +import io.ray.api.Ray; +import io.ray.streaming.runtime.config.StreamingWorkerConfig; +import io.ray.streaming.runtime.config.types.TransferChannelType; +import io.ray.streaming.runtime.context.ContextBackend; +import io.ray.streaming.runtime.context.ContextBackendFactory; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.core.processor.OneInputProcessor; +import io.ray.streaming.runtime.core.processor.ProcessBuilder; +import io.ray.streaming.runtime.core.processor.SourceProcessor; +import io.ray.streaming.runtime.core.processor.StreamProcessor; +import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest; +import io.ray.streaming.runtime.message.CallResult; +import io.ray.streaming.runtime.rpc.RemoteCallMaster; +import io.ray.streaming.runtime.transfer.TransferHandler; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo.ChannelCreationStatus; +import io.ray.streaming.runtime.util.CheckpointStateUtil; +import 
io.ray.streaming.runtime.util.EnvUtil; +import io.ray.streaming.runtime.util.Serializer; +import io.ray.streaming.runtime.worker.context.JobWorkerContext; +import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask; +import io.ray.streaming.runtime.worker.tasks.SourceStreamTask; +import io.ray.streaming.runtime.worker.tasks.StreamTask; +import java.io.Serializable; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The streaming worker implementation class, it is ray actor. JobWorker is created by {@link + * JobMaster} through ray api, and JobMaster communicates with JobWorker through Ray.call(). + * + *

The JobWorker is responsible for creating tasks and defines the methods of communication + * between workers. + */ +public class JobWorker implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(JobWorker.class); + + // special flag to indicate this actor is not ready + private static final byte[] NOT_READY_FLAG = new byte[4]; + + static { + EnvUtil.loadNativeLibraries(); + } + + public final Object initialStateChangeLock = new Object(); + /** isRecreate=true means this worker is initialized more than once after the actor is created. */ + public AtomicBoolean isRecreate = new AtomicBoolean(false); + + public ContextBackend contextBackend; + private JobWorkerContext workerContext; + private ExecutionVertex executionVertex; + private StreamingWorkerConfig workerConfig; + /** The while-loop thread to read message, process message, and write results */ + private StreamTask task; + /** transferHandler handles messages by ray direct call */ + private TransferHandler transferHandler; + /** + * A flag to avoid duplicated rollback. Becomes true after requesting rollback, set to false when + * rollback finishes. + */ + private boolean isNeedRollback = false; + + private int rollbackCount = 0; + + public JobWorker(ExecutionVertex executionVertex) { + LOG.info("Creating job worker."); + + // TODO: the following 3 lines are duplicated with those in init(), try to optimise it later. 
+ this.executionVertex = executionVertex; + this.workerConfig = new StreamingWorkerConfig(executionVertex.getWorkerConfig()); + this.contextBackend = ContextBackendFactory.getContextBackend(this.workerConfig); + + LOG.info( + "Ray.getRuntimeContext().wasCurrentActorRestarted()={}", + Ray.getRuntimeContext().wasCurrentActorRestarted()); + if (!Ray.getRuntimeContext().wasCurrentActorRestarted()) { + saveContext(); + LOG.info("Job worker is fresh started, init success."); + return; + } + + LOG.info("Begin load job worker checkpoint state."); + + byte[] bytes = CheckpointStateUtil.get(contextBackend, getJobWorkerContextKey()); + if (bytes != null) { + JobWorkerContext context = Serializer.decode(bytes); + LOG.info( + "Worker recover from checkpoint state, byte len={}, context={}.", bytes.length, context); + init(context); + requestRollback("LoadCheckpoint request rollback in new actor."); + } else { + LOG.error( + "Worker is reconstructed, but can't load checkpoint. " + + "Check whether you checkpoint state is reliable. Current checkpoint state is {}.", + contextBackend.getClass().getName()); + } + } + + public synchronized void saveContext() { + byte[] contextBytes = Serializer.encode(workerContext); + String key = getJobWorkerContextKey(); + LOG.info( + "Saving context, worker context={}, serialized byte length={}, key={}.", + workerContext, + contextBytes.length, + key); + CheckpointStateUtil.put(contextBackend, key, contextBytes); + } + + /** Initialize JobWorker and data communication pipeline. */ + public Boolean init(JobWorkerContext workerContext) { + // IMPORTANT: some test cases depends on this log to find workers' pid, + // be careful when changing this log. + LOG.info( + "Initiating job worker: {}. 
Worker context is: {}, pid={}.", + workerContext.getWorkerName(), + workerContext, + EnvUtil.getJvmPid()); + + this.workerContext = workerContext; + this.executionVertex = workerContext.getExecutionVertex(); + this.workerConfig = new StreamingWorkerConfig(executionVertex.getWorkerConfig()); + // init state backend + this.contextBackend = ContextBackendFactory.getContextBackend(this.workerConfig); + + LOG.info("Initiating job worker succeeded: {}.", workerContext.getWorkerName()); + saveContext(); + return true; + } + + /** + * Start worker's stream tasks with specific checkpoint ID. + * + * @return a {@link CallResult} with {@link ChannelRecoverInfo}, contains {@link + * ChannelCreationStatus} of each input queue. + */ + public CallResult rollback(Long checkpointId, Long startRollbackTs) { + synchronized (initialStateChangeLock) { + if (task != null + && task.isAlive() + && checkpointId == task.lastCheckpointId + && task.isInitialState) { + return CallResult.skipped("Task is already in initial state, skip this rollback."); + } + } + long remoteCallCost = System.currentTimeMillis() - startRollbackTs; + + LOG.info( + "Start rollback[{}], checkpoint is {}, remote call cost {}ms.", + executionVertex.getExecutionJobVertexName(), + checkpointId, + remoteCallCost); + + rollbackCount++; + if (rollbackCount > 1) { + isRecreate.set(true); + } + + try { + // Init transfer + TransferChannelType channelType = workerConfig.transferConfig.channelType(); + if (TransferChannelType.NATIVE_CHANNEL == channelType) { + transferHandler = new TransferHandler(); + } + + if (task != null) { + // make sure the task is closed + task.close(); + task = null; + } + + // create stream task + task = createStreamTask(checkpointId); + ChannelRecoverInfo channelRecoverInfo = task.recover(isRecreate.get()); + isNeedRollback = false; + + LOG.info( + "Rollback job worker success, checkpoint is {}, channelRecoverInfo is {}.", + checkpointId, + channelRecoverInfo); + + return 
CallResult.success(channelRecoverInfo); + } catch (Exception e) { + LOG.error("Rollback job worker has exception.", e); + return CallResult.fail(ExceptionUtils.getStackTrace(e)); + } + } + + /** Create tasks based on the processor corresponding of the operator. */ + private StreamTask createStreamTask(long checkpointId) { + StreamTask task; + StreamProcessor streamProcessor = + ProcessBuilder.buildProcessor(executionVertex.getStreamOperator()); + LOG.debug("Stream processor created: {}.", streamProcessor); + + if (streamProcessor instanceof SourceProcessor) { + task = new SourceStreamTask(streamProcessor, this, checkpointId); + } else if (streamProcessor instanceof OneInputProcessor) { + task = new OneInputStreamTask(streamProcessor, this, checkpointId); + } else { + throw new RuntimeException("Unsupported processor type:" + streamProcessor); + } + LOG.info("Stream task created: {}.", task); + return task; + } + + // ---------------------------------------------------------------------- + // Checkpoint + // ---------------------------------------------------------------------- + + /** Trigger source job worker checkpoint */ + public Boolean triggerCheckpoint(Long barrierId) { + LOG.info("Receive trigger, barrierId is {}.", barrierId); + if (task != null) { + return task.triggerCheckpoint(barrierId); + } + return false; + } + + public Boolean notifyCheckpointTimeout(Long checkpointId) { + LOG.info("Notify checkpoint timeout, checkpoint id is {}.", checkpointId); + if (task != null) { + task.notifyCheckpointTimeout(checkpointId); + } + return true; + } + + public Boolean clearExpiredCheckpoint(Long expiredStateCpId, Long expiredQueueCpId) { + LOG.info( + "Clear expired checkpoint state, checkpoint id is {}; " + + "Clear expired queue msg, checkpoint id is {}", + expiredStateCpId, + expiredQueueCpId); + if (task != null) { + if (expiredStateCpId > 0) { + task.clearExpiredCpState(expiredStateCpId); + } + task.clearExpiredQueueMsg(expiredQueueCpId); + } + return true; + 
} + + // ---------------------------------------------------------------------- + // Failover + // ---------------------------------------------------------------------- + public void requestRollback(String exceptionMsg) { + LOG.info("Request rollback."); + isNeedRollback = true; + isRecreate.set(true); + boolean requestRet = + RemoteCallMaster.requestJobWorkerRollback( + workerContext.getMaster(), + new WorkerRollbackRequest( + workerContext.getWorkerActorId(), + exceptionMsg, + EnvUtil.getHostName(), + EnvUtil.getJvmPid())); + if (!requestRet) { + LOG.warn("Job worker request rollback failed! exceptionMsg={}.", exceptionMsg); + } + } + + public Boolean checkIfNeedRollback(Long startCallTs) { + // No save checkpoint in this query. + long remoteCallCost = System.currentTimeMillis() - startCallTs; + LOG.info( + "Finished checking if need to rollback with result: {}, rpc delay={}ms.", + isNeedRollback, + remoteCallCost); + return isNeedRollback; + } + + public StreamingWorkerConfig getWorkerConfig() { + return workerConfig; + } + + public JobWorkerContext getWorkerContext() { + return workerContext; + } + + public ExecutionVertex getExecutionVertex() { + return executionVertex; + } + + public StreamTask getTask() { + return task; + } + + private String getJobWorkerContextKey() { + return workerConfig.checkpointConfig.jobWorkerContextCpPrefixKey() + + workerConfig.commonConfig.jobName() + + "_" + + executionVertex.getExecutionVertexId(); + } + + /** Used by upstream streaming queue to send data to this actor */ + public void onReaderMessage(byte[] buffer) { + if (transferHandler != null) { + transferHandler.onReaderMessage(buffer); + } + } + + /** + * Used by upstream streaming queue to send data to this actor and receive result from this actor + */ + public byte[] onReaderMessageSync(byte[] buffer) { + if (transferHandler == null) { + return NOT_READY_FLAG; + } + return transferHandler.onReaderMessageSync(buffer); + } + + /** Used by downstream streaming queue to 
send data to this actor */ + public void onWriterMessage(byte[] buffer) { + if (transferHandler != null) { + transferHandler.onWriterMessage(buffer); + } + } + + /** + * Used by downstream streaming queue to send data to this actor and receive result from this + * actor + */ + public byte[] onWriterMessageSync(byte[] buffer) { + if (transferHandler == null) { + return NOT_READY_FLAG; + } + return transferHandler.onWriterMessageSync(buffer); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/JobWorkerContext.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/JobWorkerContext.java new file mode 100644 index 00000000..d92e95bb --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/JobWorkerContext.java @@ -0,0 +1,85 @@ +package io.ray.streaming.runtime.worker.context; + +import com.google.common.base.MoreObjects; +import com.google.protobuf.ByteString; +import io.ray.api.ActorHandle; +import io.ray.api.id.ActorId; +import io.ray.runtime.actor.NativeActorHandle; +import io.ray.streaming.runtime.config.global.CommonConfig; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.python.GraphPbBuilder; +import java.io.Serializable; +import java.util.Map; + +/** Job worker context of java type. */ +public class JobWorkerContext implements Serializable { + + /** JobMaster actor. */ + private ActorHandle master; + + /** Worker's vertex info. 
*/ + private ExecutionVertex executionVertex; + + public JobWorkerContext(ActorHandle master, ExecutionVertex executionVertex) { + this.master = master; + this.executionVertex = executionVertex; + } + + public ActorId getWorkerActorId() { + return executionVertex.getWorkerActorId(); + } + + public int getWorkerId() { + return executionVertex.getExecutionVertexId(); + } + + public String getWorkerName() { + return executionVertex.getExecutionVertexName(); + } + + public Map getConfig() { + return executionVertex.getWorkerConfig(); + } + + public ActorHandle getMaster() { + return master; + } + + public ExecutionVertex getExecutionVertex() { + return executionVertex; + } + + public Map getConf() { + return getExecutionVertex().getWorkerConfig(); + } + + public String getJobName() { + return getConf().get(CommonConfig.JOB_NAME); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("workerId", getWorkerId()) + .add("workerName", getWorkerName()) + .add("config", getConfig()) + .toString(); + } + + public byte[] getPythonWorkerContextBytes() { + // create python worker context + RemoteCall.ExecutionVertexContext executionVertexContext = + new GraphPbBuilder().buildExecutionVertexContext(executionVertex); + + byte[] contextBytes = + RemoteCall.PythonJobWorkerContext.newBuilder() + .setMasterActor(ByteString.copyFrom((((NativeActorHandle) (master)).toBytes()))) + .setExecutionVertexContext(executionVertexContext) + .build() + .toByteArray(); + + return contextBytes; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/StreamingRuntimeContext.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/StreamingRuntimeContext.java new file mode 100644 index 00000000..73fe4df0 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/StreamingRuntimeContext.java @@ -0,0 +1,119 @@ +package 
io.ray.streaming.runtime.worker.context; + +import com.google.common.base.Preconditions; +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.state.backend.AbstractKeyStateBackend; +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.backend.OperatorStateBackend; +import io.ray.streaming.state.keystate.desc.AbstractStateDescriptor; +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import io.ray.streaming.state.keystate.desc.MapStateDescriptor; +import io.ray.streaming.state.keystate.desc.ValueStateDescriptor; +import io.ray.streaming.state.keystate.state.ListState; +import io.ray.streaming.state.keystate.state.MapState; +import io.ray.streaming.state.keystate.state.ValueState; +import java.util.Map; + +/** Use Ray to implement RuntimeContext. */ +public class StreamingRuntimeContext implements RuntimeContext { + + /** Backend for keyed state. This might be empty if we're not on a keyed stream. */ + protected transient KeyStateBackend keyStateBackend; + /** Backend for operator state. 
This might be empty */ + protected transient OperatorStateBackend operatorStateBackend; + + private int taskId; + private int taskIndex; + private int parallelism; + private Long checkpointId; + private Map config; + + public StreamingRuntimeContext( + ExecutionVertex executionVertex, Map config, int parallelism) { + this.taskId = executionVertex.getExecutionVertexId(); + this.config = config; + this.taskIndex = executionVertex.getExecutionVertexIndex(); + this.parallelism = parallelism; + } + + @Override + public int getTaskId() { + return taskId; + } + + @Override + public int getTaskIndex() { + return taskIndex; + } + + @Override + public int getParallelism() { + return parallelism; + } + + @Override + public Map getConfig() { + return config; + } + + @Override + public Map getJobConfig() { + return config; + } + + @Override + public Long getCheckpointId() { + return checkpointId; + } + + @Override + public void setCheckpointId(long checkpointId) { + if (this.keyStateBackend != null) { + this.keyStateBackend.setCheckpointId(checkpointId); + } + if (this.operatorStateBackend != null) { + this.operatorStateBackend.setCheckpointId(checkpointId); + } + this.checkpointId = checkpointId; + } + + @Override + public void setCurrentKey(Object key) { + this.keyStateBackend.setCurrentKey(key); + } + + @Override + public KeyStateBackend getKeyStateBackend() { + return keyStateBackend; + } + + @Override + public void setKeyStateBackend(KeyStateBackend keyStateBackend) { + this.keyStateBackend = keyStateBackend; + } + + @Override + public ValueState getValueState(ValueStateDescriptor stateDescriptor) { + stateSanityCheck(stateDescriptor, this.keyStateBackend); + return this.keyStateBackend.getValueState(stateDescriptor); + } + + @Override + public ListState getListState(ListStateDescriptor stateDescriptor) { + stateSanityCheck(stateDescriptor, this.keyStateBackend); + return this.keyStateBackend.getListState(stateDescriptor); + } + + @Override + public MapState 
getMapState(MapStateDescriptor stateDescriptor) { + stateSanityCheck(stateDescriptor, this.keyStateBackend); + return this.keyStateBackend.getMapState(stateDescriptor); + } + + protected void stateSanityCheck( + AbstractStateDescriptor stateDescriptor, AbstractKeyStateBackend backend) { + Preconditions.checkNotNull(stateDescriptor, "The state properties must not be null"); + Preconditions.checkNotNull(backend, "backend must not be null"); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java new file mode 100644 index 00000000..9ac14ad0 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java @@ -0,0 +1,101 @@ +package io.ray.streaming.runtime.worker.tasks; + +import com.google.common.base.MoreObjects; +import io.ray.streaming.runtime.core.processor.Processor; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.serialization.CrossLangSerializer; +import io.ray.streaming.runtime.serialization.JavaSerializer; +import io.ray.streaming.runtime.serialization.Serializer; +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; +import io.ray.streaming.runtime.transfer.exception.ChannelInterruptException; +import io.ray.streaming.runtime.transfer.message.BarrierMessage; +import io.ray.streaming.runtime.transfer.message.ChannelMessage; +import io.ray.streaming.runtime.transfer.message.DataMessage; +import io.ray.streaming.runtime.worker.JobWorker; +import java.util.Map; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class InputStreamTask extends StreamTask { + + private static final Logger LOG = LoggerFactory.getLogger(InputStreamTask.class); + + private final 
io.ray.streaming.runtime.serialization.Serializer javaSerializer; + private final io.ray.streaming.runtime.serialization.Serializer crossLangSerializer; + private final long readTimeoutMillis; + + public InputStreamTask(Processor processor, JobWorker jobWorker, long lastCheckpointId) { + super(processor, jobWorker, lastCheckpointId); + readTimeoutMillis = jobWorker.getWorkerConfig().transferConfig.readerTimerIntervalMs(); + javaSerializer = new JavaSerializer(); + crossLangSerializer = new CrossLangSerializer(); + } + + @Override + protected void init() {} + + @Override + public void run() { + try { + while (running) { + ChannelMessage item; + + // reader.read() will change the consumer state once it got an item. This lock is to + // ensure worker can get correct isInitialState value in exactly-once-mode's rollback. + synchronized (jobWorker.initialStateChangeLock) { + item = reader.read(readTimeoutMillis); + if (item != null) { + isInitialState = false; + } else { + continue; + } + } + + if (item instanceof DataMessage) { + DataMessage dataMessage = (DataMessage) item; + byte[] bytes = new byte[dataMessage.body().remaining() - 1]; + byte typeId = dataMessage.body().get(); + dataMessage.body().get(bytes); + Object obj; + if (typeId == Serializer.JAVA_TYPE_ID) { + obj = javaSerializer.deserialize(bytes); + } else { + obj = crossLangSerializer.deserialize(bytes); + } + processor.process(obj); + } else if (item instanceof BarrierMessage) { + final BarrierMessage queueBarrier = (BarrierMessage) item; + byte[] barrierData = new byte[queueBarrier.getData().remaining()]; + queueBarrier.getData().get(barrierData); + RemoteCall.Barrier barrierPb = RemoteCall.Barrier.parseFrom(barrierData); + final long checkpointId = barrierPb.getId(); + LOG.info( + "Start to do checkpoint {}, worker name is {}.", + checkpointId, + jobWorker.getWorkerContext().getWorkerName()); + + final Map inputPoints = queueBarrier.getInputOffsets(); + doCheckpoint(checkpointId, inputPoints); + 
LOG.info("Do checkpoint {} success.", checkpointId); + } + } + } catch (Throwable throwable) { + if (throwable instanceof ChannelInterruptException + || ExceptionUtils.getRootCause(throwable) instanceof ChannelInterruptException) { + LOG.info("queue has stopped."); + } else { + // error occurred, need to rollback + LOG.error("Last success checkpointId={}, now occur error.", lastCheckpointId, throwable); + requestRollback(ExceptionUtils.getStackTrace(throwable)); + } + } + LOG.info("Input stream task thread exit."); + stopped = true; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("processor", processor).toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java new file mode 100644 index 00000000..483f88ae --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java @@ -0,0 +1,12 @@ +package io.ray.streaming.runtime.worker.tasks; + +import io.ray.streaming.runtime.core.processor.Processor; +import io.ray.streaming.runtime.worker.JobWorker; + +/** Input stream task with 1 input. Such as: map operator. 
*/ +public class OneInputStreamTask extends InputStreamTask { + + public OneInputStreamTask(Processor inputProcessor, JobWorker jobWorker, long lastCheckpointId) { + super(inputProcessor, jobWorker, lastCheckpointId); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java new file mode 100644 index 00000000..d5041a97 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java @@ -0,0 +1,87 @@ +package io.ray.streaming.runtime.worker.tasks; + +import io.ray.streaming.operator.SourceOperator; +import io.ray.streaming.runtime.core.processor.Processor; +import io.ray.streaming.runtime.core.processor.SourceProcessor; +import io.ray.streaming.runtime.transfer.exception.ChannelInterruptException; +import io.ray.streaming.runtime.worker.JobWorker; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SourceStreamTask extends StreamTask { + + private static final Logger LOG = LoggerFactory.getLogger(SourceStreamTask.class); + + private final SourceProcessor sourceProcessor; + + /** The pending barrier ID to be triggered. */ + private final AtomicReference pendingBarrier = new AtomicReference<>(); + + private long lastCheckpointId = 0; + + /** + * SourceStreamTask for executing a {@link SourceOperator}. It is responsible for running the + * corresponding source operator. 
+ */ + public SourceStreamTask(Processor sourceProcessor, JobWorker jobWorker, long lastCheckpointId) { + super(sourceProcessor, jobWorker, lastCheckpointId); + this.sourceProcessor = (SourceProcessor) processor; + } + + @Override + protected void init() {} + + @Override + public void run() { + LOG.info("Source stream task thread start."); + Long barrierId; + try { + while (running) { + isInitialState = false; + + // check checkpoint + barrierId = pendingBarrier.get(); + if (barrierId != null) { + // Important: because cp maybe timeout, master will use the old checkpoint id again + if (pendingBarrier.compareAndSet(barrierId, null)) { + // source fetcher only have outputPoints + LOG.info( + "Start to do checkpoint {}, worker name is {}.", + barrierId, + jobWorker.getWorkerContext().getWorkerName()); + + doCheckpoint(barrierId, null); + + LOG.info("Finish to do checkpoint {}.", barrierId); + } else { + // pendingCheckpointId has modify, should not happen + LOG.warn( + "Pending checkpointId modify unexpected, expect={}, now={}.", + barrierId, + pendingBarrier.get()); + } + } + + sourceProcessor.fetch(); + } + } catch (Throwable e) { + if (e instanceof ChannelInterruptException + || ExceptionUtils.getRootCause(e) instanceof ChannelInterruptException) { + LOG.info("queue has stopped."); + } else { + // occur error, need to rollback + LOG.error("Last success checkpointId={}, now occur error.", lastCheckpointId, e); + requestRollback(ExceptionUtils.getStackTrace(e)); + } + } + + LOG.info("Source stream task thread exit."); + } + + @Override + public boolean triggerCheckpoint(Long barrierId) { + return pendingBarrier.compareAndSet(null, barrierId); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java new file mode 100644 index 00000000..6acf016e --- /dev/null +++ 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java @@ -0,0 +1,376 @@ +package io.ray.streaming.runtime.worker.tasks; + +import io.ray.api.BaseActorHandle; +import io.ray.api.Ray; +import io.ray.streaming.api.collector.Collector; +import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.runtime.config.worker.WorkerInternalConfig; +import io.ray.streaming.runtime.context.ContextBackend; +import io.ray.streaming.runtime.context.OperatorCheckpointInfo; +import io.ray.streaming.runtime.core.collector.OutputCollector; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.core.processor.Processor; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport; +import io.ray.streaming.runtime.rpc.RemoteCallMaster; +import io.ray.streaming.runtime.transfer.DataReader; +import io.ray.streaming.runtime.transfer.DataWriter; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; +import io.ray.streaming.runtime.util.CheckpointStateUtil; +import io.ray.streaming.runtime.util.Serializer; +import io.ray.streaming.runtime.worker.JobWorker; +import io.ray.streaming.runtime.worker.context.JobWorkerContext; +import io.ray.streaming.runtime.worker.context.StreamingRuntimeContext; +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@link StreamTask} is a 
while-loop thread to read message, process message, and send result + * messages to downstream operators + */ +public abstract class StreamTask implements Runnable { + + private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class); + private final ContextBackend checkpointState; + public volatile boolean isInitialState = true; + public long lastCheckpointId; + protected Processor processor; + protected JobWorker jobWorker; + protected DataReader reader; + protected DataWriter writer; + protected volatile boolean running = true; + protected volatile boolean stopped = false; + List collectors = new ArrayList<>(); + private Set outdatedCheckpoints = new HashSet<>(); + private Thread thread; + + protected StreamTask(Processor processor, JobWorker jobWorker, long lastCheckpointId) { + this.processor = processor; + this.jobWorker = jobWorker; + this.checkpointState = jobWorker.contextBackend; + this.lastCheckpointId = lastCheckpointId; + + this.thread = + new Thread( + Ray.wrapRunnable(this), this.getClass().getName() + "-" + System.currentTimeMillis()); + this.thread.setDaemon(true); + } + + public ChannelRecoverInfo recover(boolean isRecover) { + + if (isRecover) { + LOG.info("Stream task begin recover."); + } else { + LOG.info("Stream task first start begin."); + } + prepareTask(isRecover); + + // start runner + ChannelRecoverInfo recoverInfo = new ChannelRecoverInfo(new HashMap<>()); + if (reader != null) { + recoverInfo = reader.getQueueRecoverInfo(); + } + + thread.setUncaughtExceptionHandler( + (t, e) -> LOG.error("Uncaught exception in runner thread.", e)); + LOG.info("Start stream task: {}.", this.getClass().getSimpleName()); + thread.start(); + + if (isRecover) { + LOG.info("Stream task recover end."); + } else { + LOG.info("Stream task first start finished."); + } + + return recoverInfo; + } + + /** + * Load checkpoint and build upstream and downstream data transmission channels according to + * {@link ExecutionVertex}. 
+ */ + private void prepareTask(boolean isRecreate) { + LOG.info("Preparing stream task, isRecreate={}.", isRecreate); + ExecutionVertex executionVertex = jobWorker.getExecutionVertex(); + + // set vertex info into config for native using + jobWorker + .getWorkerConfig() + .workerInternalConfig + .setProperty( + WorkerInternalConfig.WORKER_NAME_INTERNAL, executionVertex.getExecutionVertexName()); + jobWorker + .getWorkerConfig() + .workerInternalConfig + .setProperty( + WorkerInternalConfig.OP_NAME_INTERNAL, executionVertex.getExecutionJobVertexName()); + + OperatorCheckpointInfo operatorCheckpointInfo = new OperatorCheckpointInfo(); + byte[] bytes = null; + + // Fetch checkpoint from storage only in recreate mode not for new startup worker + // in rescaling or something like that. + if (isRecreate) { + String cpKey = genOpCheckpointKey(lastCheckpointId); + LOG.info( + "Getting task checkpoints from state, cpKey={}, checkpointId={}.", + cpKey, + lastCheckpointId); + bytes = CheckpointStateUtil.get(checkpointState, cpKey); + if (bytes == null) { + String msg = String.format("Task recover failed, checkpoint is null! 
cpKey=%s", cpKey); + throw new RuntimeException(msg); + } + } + + // when use memory state, if actor throw exception, will miss state + if (bytes != null) { + operatorCheckpointInfo = Serializer.decode(bytes); + processor.loadCheckpoint(operatorCheckpointInfo.processorCheckpoint); + LOG.info( + "Stream task recover from checkpoint state, checkpoint bytes len={}, checkpointInfo={}.", + bytes.length, + operatorCheckpointInfo); + } + + // writer + if (!executionVertex.getOutputEdges().isEmpty()) { + LOG.info( + "Register queue writer, channels={}, outputCheckpoints={}.", + executionVertex.getOutputChannelIdList(), + operatorCheckpointInfo.outputPoints); + writer = + new DataWriter( + executionVertex.getOutputChannelIdList(), + executionVertex.getOutputActorList(), + operatorCheckpointInfo.outputPoints, + jobWorker.getWorkerConfig()); + } + + // reader + if (!executionVertex.getInputEdges().isEmpty()) { + LOG.info( + "Register queue reader, channels={}, inputCheckpoints={}.", + executionVertex.getInputChannelIdList(), + operatorCheckpointInfo.inputPoints); + reader = + new DataReader( + executionVertex.getInputChannelIdList(), + executionVertex.getInputActorList(), + operatorCheckpointInfo.inputPoints, + jobWorker.getWorkerConfig()); + } + + openProcessor(); + + LOG.debug("Finished preparing stream task."); + } + + /** + * Create one collector for each distinct output operator(i.e. 
each {@link ExecutionJobVertex}) + */ + private void openProcessor() { + ExecutionVertex executionVertex = jobWorker.getExecutionVertex(); + List outputEdges = executionVertex.getOutputEdges(); + + Map> opGroupedChannelId = new HashMap<>(); + Map> opGroupedActor = new HashMap<>(); + Map opPartitionMap = new HashMap<>(); + for (int i = 0; i < outputEdges.size(); ++i) { + ExecutionEdge edge = outputEdges.get(i); + String opName = edge.getTargetExecutionJobVertexName(); + if (!opPartitionMap.containsKey(opName)) { + opGroupedChannelId.put(opName, new ArrayList<>()); + opGroupedActor.put(opName, new ArrayList<>()); + } + opGroupedChannelId.get(opName).add(executionVertex.getOutputChannelIdList().get(i)); + opGroupedActor.get(opName).add(executionVertex.getOutputActorList().get(i)); + opPartitionMap.put(opName, edge.getPartition()); + } + opPartitionMap + .keySet() + .forEach( + opName -> { + collectors.add( + new OutputCollector( + writer, + opGroupedChannelId.get(opName), + opGroupedActor.get(opName), + opPartitionMap.get(opName))); + }); + + RuntimeContext runtimeContext = + new StreamingRuntimeContext( + executionVertex, + jobWorker.getWorkerConfig().configMap, + executionVertex.getParallelism()); + + processor.open(collectors, runtimeContext); + } + + /** Task initialization related work. */ + protected abstract void init() throws Exception; + + /** Close running tasks. */ + public void close() { + this.running = false; + if (thread.isAlive() && !Ray.getRuntimeContext().isSingleProcess()) { + // `Runtime.halt` is used because System.exist can't ensure the process killing. 
+ Runtime.getRuntime().halt(0); + LOG.warn("runtime halt 0"); + System.exit(0); + } + LOG.info("Stream task close success."); + } + + // ---------------------------------------------------------------------- + // Checkpoint + // ---------------------------------------------------------------------- + + public boolean triggerCheckpoint(Long barrierId) { + throw new UnsupportedOperationException("Only source operator supports trigger checkpoints."); + } + + public void doCheckpoint(long checkpointId, Map inputPoints) { + Map outputPoints = null; + if (writer != null) { + outputPoints = writer.getOutputCheckpoints(); + RemoteCall.Barrier barrierPb = RemoteCall.Barrier.newBuilder().setId(checkpointId).build(); + ByteBuffer byteBuffer = ByteBuffer.wrap(barrierPb.toByteArray()); + byteBuffer.order(ByteOrder.nativeOrder()); + writer.broadcastBarrier(checkpointId, byteBuffer); + } + + LOG.info( + "Start do checkpoint, cp id={}, inputPoints={}, outputPoints={}.", + checkpointId, + inputPoints, + outputPoints); + + this.lastCheckpointId = checkpointId; + Serializable processorCheckpoint = processor.saveCheckpoint(); + + try { + OperatorCheckpointInfo opCpInfo = + new OperatorCheckpointInfo(inputPoints, outputPoints, processorCheckpoint, checkpointId); + saveCpStateAndReport(opCpInfo, checkpointId); + } catch (Exception e) { + // there will be exceptions when flush state to backend. 
+ // we ignore the exception to prevent failover + LOG.error("Processor or op checkpoint exception.", e); + } + + LOG.info("Operator do checkpoint {} finish.", checkpointId); + } + + private void saveCpStateAndReport( + OperatorCheckpointInfo operatorCheckpointInfo, long checkpointId) { + saveCp(operatorCheckpointInfo, checkpointId); + reportCommit(checkpointId); + + LOG.info("Finish save cp state and report, checkpoint id is {}.", checkpointId); + } + + private void saveCp(OperatorCheckpointInfo operatorCheckpointInfo, long checkpointId) { + byte[] bytes = Serializer.encode(operatorCheckpointInfo); + String cpKey = genOpCheckpointKey(checkpointId); + LOG.info( + "Saving task checkpoint, cpKey={}, byte len={}, checkpointInfo={}.", + cpKey, + bytes.length, + operatorCheckpointInfo); + synchronized (checkpointState) { + if (outdatedCheckpoints.contains(checkpointId)) { + LOG.info("Outdated checkpoint, skip save checkpoint."); + outdatedCheckpoints.remove(checkpointId); + } else { + CheckpointStateUtil.put(checkpointState, cpKey, bytes); + } + } + } + + private void reportCommit(long checkpointId) { + final JobWorkerContext context = jobWorker.getWorkerContext(); + LOG.info("Report commit async, checkpoint id {}.", checkpointId); + RemoteCallMaster.reportJobWorkerCommitAsync( + context.getMaster(), new WorkerCommitReport(context.getWorkerActorId(), checkpointId)); + } + + public void notifyCheckpointTimeout(long checkpointId) { + String cpKey = genOpCheckpointKey(checkpointId); + try { + synchronized (checkpointState) { + if (checkpointState.exists(cpKey)) { + checkpointState.remove(cpKey); + } else { + outdatedCheckpoints.add(checkpointId); + } + } + } catch (Exception e) { + LOG.error("Notify checkpoint timeout failed, checkpointId is {}.", checkpointId, e); + } + } + + public void clearExpiredCpState(long checkpointId) { + String cpKey = genOpCheckpointKey(checkpointId); + try { + checkpointState.remove(cpKey); + } catch (Exception e) { + LOG.error("Failed to 
remove key {} from state backend.", cpKey, e); + } + } + + public void clearExpiredQueueMsg(long checkpointId) { + // get operator checkpoint + String cpKey = genOpCheckpointKey(checkpointId); + byte[] bytes; + try { + bytes = checkpointState.get(cpKey); + } catch (Exception e) { + LOG.error("Failed to get key {} from state backend.", cpKey, e); + return; + } + if (bytes != null) { + final OperatorCheckpointInfo operatorCheckpointInfo = Serializer.decode(bytes); + long cpId = operatorCheckpointInfo.checkpointId; + if (writer != null) { + writer.clearCheckpoint(cpId); + } + } + } + + public String genOpCheckpointKey(long checkpointId) { + // TODO: need to support job restart and actorId changed + final JobWorkerContext context = jobWorker.getWorkerContext(); + return jobWorker.getWorkerConfig().checkpointConfig.jobWorkerOpCpPrefixKey() + + context.getJobName() + + "_" + + context.getWorkerName() + + "_" + + checkpointId; + } + + // ---------------------------------------------------------------------- + // Failover + // ---------------------------------------------------------------------- + protected void requestRollback(String exceptionMsg) { + jobWorker.requestRollback(exceptionMsg); + } + + public boolean isAlive() { + return this.thread.isAlive(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/TwoInputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/TwoInputStreamTask.java new file mode 100644 index 00000000..3ae3c6fe --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/TwoInputStreamTask.java @@ -0,0 +1,20 @@ +package io.ray.streaming.runtime.worker.tasks; + +import io.ray.streaming.runtime.core.processor.Processor; +import io.ray.streaming.runtime.core.processor.TwoInputProcessor; +import io.ray.streaming.runtime.worker.JobWorker; + +/** Input stream task with 2 inputs. Such as: join operator. 
*/ +public class TwoInputStreamTask extends InputStreamTask { + + public TwoInputStreamTask( + Processor processor, + JobWorker jobWorker, + String leftStream, + String rightStream, + long lastCheckpointId) { + super(processor, jobWorker, lastCheckpointId); + ((TwoInputProcessor) (super.processor)).setLeftStream(leftStream); + ((TwoInputProcessor) (super.processor)).setRightStream(rightStream); + } +} diff --git a/streaming/java/streaming-runtime/src/main/resources/META-INF/services/io.ray.streaming.client.JobClient b/streaming/java/streaming-runtime/src/main/resources/META-INF/services/io.ray.streaming.client.JobClient new file mode 100644 index 00000000..b592c731 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/resources/META-INF/services/io.ray.streaming.client.JobClient @@ -0,0 +1 @@ +io.ray.streaming.runtime.client.JobClientImpl \ No newline at end of file diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/BaseUnitTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/BaseUnitTest.java new file mode 100644 index 00000000..9a343f06 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/BaseUnitTest.java @@ -0,0 +1,40 @@ +package io.ray.streaming.runtime; + +import java.lang.reflect.Method; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.BeforeMethod; + +public abstract class BaseUnitTest { + + private static final Logger LOG = LoggerFactory.getLogger(BaseUnitTest.class); + + @BeforeClass + public void setUp() { + TestHelper.setUTFlag(); + } + + @AfterClass + public void tearDown() { + TestHelper.clearUTFlag(); + } + + @BeforeMethod + public void testBegin(Method method) { + LOG.info( + ">>>>>>>>>>>>>>>>>>>> Test case: {}.{} began >>>>>>>>>>>>>>>>>>>>", + 
method.getDeclaringClass(), + method.getName()); + } + + @AfterMethod + public void testEnd(Method method) { + LOG.info( + ">>>>>>>>>>>>>>>>>>>> Test case: {}.{} end >>>>>>>>>>>>>>>>>>>>", + method.getDeclaringClass(), + method.getName()); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/TestHelper.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/TestHelper.java new file mode 100644 index 00000000..af0aec2d --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/TestHelper.java @@ -0,0 +1,18 @@ +package io.ray.streaming.runtime; + +public class TestHelper { + + private static volatile boolean UT_FLAG = false; + + public static void setUTFlag() { + UT_FLAG = true; + } + + public static void clearUTFlag() { + UT_FLAG = false; + } + + public static boolean isUT() { + return UT_FLAG; + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/config/ConfigTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/config/ConfigTest.java new file mode 100644 index 00000000..0a7ca3dd --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/config/ConfigTest.java @@ -0,0 +1,64 @@ +package io.ray.streaming.runtime.config; + +import io.ray.streaming.runtime.BaseUnitTest; +import io.ray.streaming.runtime.config.global.CommonConfig; +import java.util.HashMap; +import java.util.Map; +import org.aeonbits.owner.ConfigFactory; +import org.nustaq.serialization.FSTConfiguration; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ConfigTest extends BaseUnitTest { + + @Test + public void testBaseFunc() { + // conf using + CommonConfig commonConfig = ConfigFactory.create(CommonConfig.class); + Assert.assertTrue(commonConfig.jobId().equals("default-job-id")); + + // override conf + Map customConf = new HashMap<>(); + customConf.put(CommonConfig.JOB_ID, "111"); 
+ CommonConfig commonConfig2 = ConfigFactory.create(CommonConfig.class, customConf); + Assert.assertTrue(commonConfig2.jobId().equals("111")); + } + + @Test + public void testMapTransformation() { + Map conf = new HashMap<>(); + String testValue = "222"; + conf.put(CommonConfig.JOB_ID, testValue); + + StreamingConfig config = new StreamingConfig(conf); + Map wholeConfigMap = config.getMap(); + + Assert.assertTrue(wholeConfigMap.get(CommonConfig.JOB_ID).equals(testValue)); + } + + @Test + public void testCustomConfKeeping() { + Map conf = new HashMap<>(); + String customKey = "test_key"; + String customValue = "test_value"; + conf.put(customKey, customValue); + StreamingConfig config = new StreamingConfig(conf); + Assert.assertEquals(config.getMap().get(customKey), customValue); + } + + @Test + public void testSerialization() { + Map conf = new HashMap<>(); + String customKey = "test_key"; + String customValue = "test_value"; + conf.put(customKey, customValue); + StreamingConfig config = new StreamingConfig(conf); + + FSTConfiguration fstConf = FSTConfiguration.createDefaultConfiguration(); + byte[] configBytes = fstConf.asByteArray(config); + StreamingConfig deserializedConfig = (StreamingConfig) fstConf.asObject(configBytes); + + Assert.assertEquals(deserializedConfig.masterConfig.commonConfig.jobId(), "default-job-id"); + Assert.assertEquals(deserializedConfig.getMap().get(customKey), customValue); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java new file mode 100644 index 00000000..c6bb7967 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java @@ -0,0 +1,111 @@ +package io.ray.streaming.runtime.core.graph; + +import com.google.common.collect.Lists; +import io.ray.streaming.api.context.StreamingContext; 
+import io.ray.streaming.api.stream.DataStream; +import io.ray.streaming.api.stream.DataStreamSource; +import io.ray.streaming.api.stream.StreamSink; +import io.ray.streaming.jobgraph.JobGraph; +import io.ray.streaming.jobgraph.JobGraphBuilder; +import io.ray.streaming.jobgraph.JobVertex; +import io.ray.streaming.runtime.BaseUnitTest; +import io.ray.streaming.runtime.config.StreamingConfig; +import io.ray.streaming.runtime.config.master.ResourceConfig; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.core.resource.ResourceType; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.master.graphmanager.GraphManager; +import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ExecutionGraphTest extends BaseUnitTest { + + private static final Logger LOG = LoggerFactory.getLogger(ExecutionGraphTest.class); + + @Test + public void testBuildExecutionGraph() { + Map jobConf = new HashMap<>(); + StreamingConfig streamingConfig = new StreamingConfig(jobConf); + GraphManager graphManager = new GraphManagerImpl(new JobMasterRuntimeContext(streamingConfig)); + JobGraph jobGraph = buildJobGraph(); + jobGraph.getJobConfig().put("streaming.task.resource.cpu.limitation.enable", "true"); + + ExecutionGraph executionGraph = buildExecutionGraph(graphManager, jobGraph); + List executionJobVertices = executionGraph.getExecutionJobVertexList(); + + Assert.assertEquals(executionJobVertices.size(), jobGraph.getJobVertices().size()); + + int totalVertexNum = + 
jobGraph.getJobVertices().stream().mapToInt(JobVertex::getParallelism).sum(); + Assert.assertEquals(executionGraph.getAllExecutionVertices().size(), totalVertexNum); + Assert.assertEquals( + executionGraph.getAllExecutionVertices().size(), + executionGraph.getExecutionVertexIdGenerator().get()); + + executionGraph + .getAllExecutionVertices() + .forEach( + vertex -> { + Assert.assertNotNull(vertex.getStreamOperator()); + Assert.assertNotNull(vertex.getExecutionJobVertexName()); + Assert.assertNotNull(vertex.getVertexType()); + Assert.assertNotNull(vertex.getLanguage()); + Assert.assertEquals( + vertex.getExecutionVertexName(), + vertex.getExecutionJobVertexName() + "-" + vertex.getExecutionVertexIndex()); + }); + + int startIndex = 0; + ExecutionJobVertex upStream = executionJobVertices.get(startIndex); + ExecutionJobVertex downStream = executionJobVertices.get(startIndex + 1); + Assert.assertEquals(upStream.getOutputEdges().get(0).getTargetExecutionJobVertex(), downStream); + + List upStreamVertices = upStream.getExecutionVertices(); + List downStreamVertices = downStream.getExecutionVertices(); + upStreamVertices.forEach( + vertex -> { + Assert.assertEquals((double) vertex.getResource().get(ResourceType.CPU.name()), 2.0); + vertex + .getOutputEdges() + .forEach( + upStreamOutPutEdge -> { + Assert.assertTrue( + downStreamVertices.contains(upStreamOutPutEdge.getTargetExecutionVertex())); + }); + }); + } + + public static ExecutionGraph buildExecutionGraph(GraphManager graphManager) { + return graphManager.buildExecutionGraph(buildJobGraph()); + } + + public static ExecutionGraph buildExecutionGraph(GraphManager graphManager, JobGraph jobGraph) { + return graphManager.buildExecutionGraph(jobGraph); + } + + public static JobGraph buildJobGraph() { + StreamingContext streamingContext = StreamingContext.buildContext(); + DataStream dataStream = + DataStreamSource.fromCollection(streamingContext, Lists.newArrayList("a", "b", "c")); + StreamSink streamSink = 
dataStream.sink(x -> LOG.info(x)); + + Map jobConfig = new HashMap<>(); + jobConfig.put("key1", "value1"); + jobConfig.put("key2", "value2"); + jobConfig.put(ResourceConfig.TASK_RESOURCE_CPU, "2.0"); + jobConfig.put(ResourceConfig.TASK_RESOURCE_MEM, "2.0"); + + JobGraphBuilder jobGraphBuilder = + new JobGraphBuilder(Lists.newArrayList(streamSink), "test", jobConfig); + + return jobGraphBuilder.build(); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java new file mode 100644 index 00000000..35eeeddc --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java @@ -0,0 +1,100 @@ +package io.ray.streaming.runtime.demo; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import io.ray.api.Ray; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.function.impl.FilterFunction; +import io.ray.streaming.api.function.impl.MapFunction; +import io.ray.streaming.api.function.impl.SinkFunction; +import io.ray.streaming.api.stream.DataStreamSource; +import io.ray.streaming.runtime.util.EnvUtil; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.Arrays; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class HybridStreamTest { + + private static final Logger LOG = LoggerFactory.getLogger(HybridStreamTest.class); + + public static class Mapper1 implements MapFunction { + + @Override + public Object map(Object value) { + LOG.info("HybridStreamTest Mapper1 {}", value); + return value.toString(); + } + } + + public static class Filter1 implements FilterFunction { + + 
@Override + public boolean filter(Object value) throws Exception { + LOG.info("HybridStreamTest Filter1 {}", value); + return !value.toString().contains("b"); + } + } + + @Test(timeOut = 60000) + public void testHybridDataStream() throws Exception { + Ray.shutdown(); + Preconditions.checkArgument(EnvUtil.executeCommand(ImmutableList.of("ray", "stop"), 5)); + String sinkFileName = "/tmp/testHybridDataStream.txt"; + Files.deleteIfExists(Paths.get(sinkFileName)); + + StreamingContext context = StreamingContext.buildContext(); + DataStreamSource streamSource = + DataStreamSource.fromCollection(context, Arrays.asList("a", "b", "c")); + streamSource + .map(x -> x + x) + .asPythonStream() + .map("ray.streaming.tests.test_hybrid_stream", "map_func1") + .filter("ray.streaming.tests.test_hybrid_stream", "filter_func1") + .asJavaStream() + .sink( + (SinkFunction) + value -> { + LOG.info("HybridStreamTest: {}", value); + try { + if (!Files.exists(Paths.get(sinkFileName))) { + Files.createFile(Paths.get(sinkFileName)); + } + Files.write( + Paths.get(sinkFileName), + value.toString().getBytes(), + StandardOpenOption.APPEND); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + context.execute("HybridStreamTestJob"); + int sleptTime = 0; + TimeUnit.SECONDS.sleep(3); + while (true) { + if (Files.exists(Paths.get(sinkFileName))) { + TimeUnit.SECONDS.sleep(3); + String text = String.join(", ", Files.readAllLines(Paths.get(sinkFileName))); + Assert.assertTrue(text.contains("a")); + Assert.assertFalse(text.contains("b")); + Assert.assertTrue(text.contains("c")); + LOG.info("Execution succeed"); + break; + } + sleptTime += 1; + if (sleptTime >= 60) { + throw new RuntimeException("Execution not finished"); + } + LOG.info("Wait finish..."); + TimeUnit.SECONDS.sleep(1); + } + context.stop(); + LOG.info("HybridStreamTest succeed"); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java 
b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java new file mode 100644 index 00000000..26f679fc --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java @@ -0,0 +1,75 @@ +package io.ray.streaming.runtime.demo; + +import io.ray.api.Ray; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.function.impl.SinkFunction; +import io.ray.streaming.api.stream.DataStreamSource; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.Arrays; +import java.util.concurrent.TimeUnit; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class UnionStreamTest { + + private static final Logger LOG = LoggerFactory.getLogger(UnionStreamTest.class); + + @Test(timeOut = 60000) + public void testUnionStream() throws Exception { + Ray.shutdown(); + String sinkFileName = "/tmp/testUnionStream.txt"; + Files.deleteIfExists(Paths.get(sinkFileName)); + + StreamingContext context = StreamingContext.buildContext(); + DataStreamSource streamSource1 = + DataStreamSource.fromCollection(context, Arrays.asList(1, 1)); + DataStreamSource streamSource2 = + DataStreamSource.fromCollection(context, Arrays.asList(1, 1)); + DataStreamSource streamSource3 = + DataStreamSource.fromCollection(context, Arrays.asList(1, 1)); + streamSource1 + .union(streamSource2, streamSource3) + .sink( + (SinkFunction) + value -> { + LOG.info("UnionStreamTest, sink: {}", value); + try { + if (!Files.exists(Paths.get(sinkFileName))) { + Files.createFile(Paths.get(sinkFileName)); + } + Files.write( + Paths.get(sinkFileName), + value.toString().getBytes(), + StandardOpenOption.APPEND); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + 
context.execute("UnionStreamTest"); + int sleptTime = 0; + TimeUnit.SECONDS.sleep(3); + while (true) { + if (Files.exists(Paths.get(sinkFileName))) { + TimeUnit.SECONDS.sleep(3); + String text = String.join(", ", Files.readAllLines(Paths.get(sinkFileName))); + Assert.assertEquals(text, StringUtils.repeat("1", 6)); + LOG.info("Execution succeed"); + break; + } + sleptTime += 1; + if (sleptTime >= 60) { + throw new RuntimeException("Execution not finished"); + } + LOG.info("Wait finish..."); + TimeUnit.SECONDS.sleep(1); + } + context.stop(); + LOG.info("HybridStreamTest succeed"); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java new file mode 100644 index 00000000..dc75b2b0 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java @@ -0,0 +1,79 @@ +package io.ray.streaming.runtime.demo; + +import com.google.common.collect.ImmutableMap; +import io.ray.api.Ray; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.function.impl.FlatMapFunction; +import io.ray.streaming.api.function.impl.ReduceFunction; +import io.ray.streaming.api.function.impl.SinkFunction; +import io.ray.streaming.api.stream.DataStreamSource; +import io.ray.streaming.runtime.BaseUnitTest; +import io.ray.streaming.util.Config; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.annotations.Test; + +public class WordCountTest extends BaseUnitTest implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(WordCountTest.class); + + static Map wordCount = new ConcurrentHashMap<>(); + + @Test(timeOut = 60000) + public void 
testWordCount() { + Ray.shutdown(); + + StreamingContext streamingContext = StreamingContext.buildContext(); + Map config = new HashMap<>(); + config.put(Config.CHANNEL_TYPE, "MEMORY_CHANNEL"); + streamingContext.withConfig(config); + List text = new ArrayList<>(); + text.add("hello world eagle eagle eagle"); + DataStreamSource streamSource = DataStreamSource.fromCollection(streamingContext, text); + streamSource + .flatMap( + (FlatMapFunction) + (value, collector) -> { + String[] records = value.split(" "); + for (String record : records) { + collector.collect(new WordAndCount(record, 1)); + } + }) + .filter(pair -> !pair.word.contains("world")) + .keyBy(pair -> pair.word) + .reduce( + (ReduceFunction) + (oldValue, newValue) -> + new WordAndCount(oldValue.word, oldValue.count + newValue.count)) + .sink((SinkFunction) result -> wordCount.put(result.word, result.count)); + + streamingContext.execute("testWordCount"); + + ImmutableMap expected = ImmutableMap.of("eagle", 3, "hello", 1); + while (!wordCount.equals(expected)) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + LOG.warn("Got an exception while sleeping.", e); + } + } + streamingContext.stop(); + } + + private static class WordAndCount implements Serializable { + + public final String word; + public final Integer count; + + public WordAndCount(String key, Integer count) { + this.word = key; + this.count = count; + } + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/JobMasterTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/JobMasterTest.java new file mode 100644 index 00000000..def36a2c --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/JobMasterTest.java @@ -0,0 +1,34 @@ +package io.ray.streaming.runtime.master; + +import io.ray.api.Ray; +import io.ray.streaming.runtime.BaseUnitTest; +import java.util.HashMap; +import org.testng.Assert; +import 
org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +public class JobMasterTest extends BaseUnitTest { + + @BeforeMethod + public void init() { + // ray init + Ray.init(); + } + + @AfterMethod + public void tearDown() { + Ray.shutdown(); + } + + @Test + public void testCreation() { + JobMaster jobMaster = new JobMaster(new HashMap<>()); + Assert.assertNotNull(jobMaster.getRuntimeContext()); + Assert.assertNotNull(jobMaster.getConf()); + Assert.assertNull(jobMaster.getGraphManager()); + Assert.assertNull(jobMaster.getResourceManager()); + Assert.assertNull(jobMaster.getJobMasterActor()); + Assert.assertFalse(jobMaster.init(false)); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/jobscheduler/JobClientTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/jobscheduler/JobClientTest.java new file mode 100644 index 00000000..19c3b8d3 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/jobscheduler/JobClientTest.java @@ -0,0 +1,11 @@ +package io.ray.streaming.runtime.master.jobscheduler; + +import org.testng.annotations.Test; + +public class JobClientTest { + + @Test + public void testSchedule() { + // TODO (tianyi): need JobWorker Part to do this. 
+ } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerTest.java new file mode 100644 index 00000000..5f3e7db3 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerTest.java @@ -0,0 +1,54 @@ +package io.ray.streaming.runtime.master.resourcemanager; + +import io.ray.api.Ray; +import io.ray.api.id.UniqueId; +import io.ray.api.runtimecontext.NodeInfo; +import io.ray.streaming.runtime.BaseUnitTest; +import io.ray.streaming.runtime.config.StreamingConfig; +import io.ray.streaming.runtime.config.global.CommonConfig; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.util.RayUtils; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +public class ResourceManagerTest extends BaseUnitTest { + + private static final Logger LOG = LoggerFactory.getLogger(ResourceManagerTest.class); + + private Object rayAsyncContext; + + @BeforeMethod + public void init() { + // ray init + Ray.init(); + rayAsyncContext = Ray.getAsyncContext(); + } + + @Test + public void testGcsMockedApi() { + Map nodeInfoMap = RayUtils.getAliveNodeInfoMap(); + Assert.assertEquals(nodeInfoMap.size(), 5); + } + + @Test(dependsOnMethods = "testGcsMockedApi") + public void testApi() { + Ray.setAsyncContext(rayAsyncContext); + + Map conf = new HashMap(); + conf.put(CommonConfig.JOB_NAME, "testApi"); + StreamingConfig config = new StreamingConfig(conf); + JobMasterRuntimeContext jobMasterRuntimeContext = new JobMasterRuntimeContext(config); + 
ResourceManager resourceManager = new ResourceManagerImpl(jobMasterRuntimeContext); + + // test register container + List containers = resourceManager.getRegisteredContainers(); + Assert.assertEquals(containers.size(), 5); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/strategy/PipelineFirstStrategyTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/strategy/PipelineFirstStrategyTest.java new file mode 100644 index 00000000..ad25cb68 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/strategy/PipelineFirstStrategyTest.java @@ -0,0 +1,73 @@ +package io.ray.streaming.runtime.master.resourcemanager.strategy; + +import io.ray.api.id.UniqueId; +import io.ray.streaming.jobgraph.JobGraph; +import io.ray.streaming.runtime.BaseUnitTest; +import io.ray.streaming.runtime.config.StreamingConfig; +import io.ray.streaming.runtime.config.types.ResourceAssignStrategyType; +import io.ray.streaming.runtime.core.graph.ExecutionGraphTest; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.core.resource.ResourceType; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.master.graphmanager.GraphManager; +import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl; +import io.ray.streaming.runtime.master.resourcemanager.ResourceAssignmentView; +import io.ray.streaming.runtime.master.resourcemanager.strategy.impl.PipelineFirstStrategy; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeClass; +import 
org.testng.annotations.Test; + +public class PipelineFirstStrategyTest extends BaseUnitTest { + + private Logger LOG = LoggerFactory.getLogger(PipelineFirstStrategyTest.class); + + private List containers = new ArrayList<>(); + private ResourceAssignStrategy strategy; + + @BeforeClass + public void setUp() { + strategy = new PipelineFirstStrategy(); + for (int i = 0; i < 2; ++i) { + UniqueId uniqueId = UniqueId.randomId(); + Map resource = new HashMap<>(); + resource.put(ResourceType.CPU.getValue(), 4.0); + resource.put(ResourceType.MEM.getValue(), 4.0); + Container container = new Container("1.1.1." + i, uniqueId, "localhost" + i, resource); + container.getAvailableResources().put(container.getName(), 500.0); + containers.add(container); + } + } + + @AfterMethod + public void tearDown() { + reset(); + } + + private void reset() { + containers = new ArrayList<>(); + strategy = null; + } + + @Test + public void testResourceAssignment() { + strategy = new PipelineFirstStrategy(); + Assert.assertEquals( + ResourceAssignStrategyType.PIPELINE_FIRST_STRATEGY.getName(), strategy.getName()); + + Map jobConf = new HashMap<>(); + StreamingConfig streamingConfig = new StreamingConfig(jobConf); + GraphManager graphManager = new GraphManagerImpl(new JobMasterRuntimeContext(streamingConfig)); + JobGraph jobGraph = ExecutionGraphTest.buildJobGraph(); + ExecutionGraph executionGraph = ExecutionGraphTest.buildExecutionGraph(graphManager, jobGraph); + ResourceAssignmentView assignmentView = strategy.assignResource(containers, executionGraph); + Assert.assertNotNull(assignmentView); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/PythonGatewayTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/PythonGatewayTest.java new file mode 100644 index 00000000..dfd41252 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/PythonGatewayTest.java @@ -0,0 +1,47 
@@ +package io.ray.streaming.runtime.python; + +import static org.testng.Assert.assertEquals; + +import io.ray.streaming.api.stream.StreamSink; +import io.ray.streaming.jobgraph.JobGraph; +import io.ray.streaming.jobgraph.JobGraphBuilder; +import io.ray.streaming.runtime.serialization.MsgPackSerializer; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.testng.annotations.Test; + +public class PythonGatewayTest { + + @Test + public void testPythonGateway() { + MsgPackSerializer serializer = new MsgPackSerializer(); + PythonGateway gateway = new PythonGateway(); + gateway.createStreamingContext(); + Map config = new HashMap<>(); + config.put("k1", "v1"); + gateway.withConfig(serializer.serialize(config)); + byte[] mockPySource = new byte[0]; + Object source = serializer.deserialize(gateway.createPythonStreamSource(mockPySource)); + byte[] mockPyFunc = new byte[0]; + Object mapPyFunc = serializer.deserialize(gateway.createPyFunc(mockPyFunc)); + Object mapStream = + serializer.deserialize( + gateway.callMethod(serializer.serialize(Arrays.asList(source, "map", mapPyFunc)))); + byte[] mockPyPartition = new byte[0]; + Object partition = serializer.deserialize(gateway.createPyPartition(mockPyPartition)); + Object partitionedStream = + serializer.deserialize( + gateway.callMethod( + serializer.serialize(Arrays.asList(mapStream, "partitionBy", partition)))); + byte[] mockSinkFunc = new byte[0]; + Object sinkPyFunc = serializer.deserialize(gateway.createPyFunc(mockSinkFunc)); + gateway.callMethod(serializer.serialize(Arrays.asList(partitionedStream, "sink", sinkPyFunc))); + List streamSinks = gateway.getStreamingContext().getStreamSinks(); + assertEquals(streamSinks.size(), 1); + JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(streamSinks, "py_job"); + JobGraph jobGraph = jobGraphBuilder.build(); + jobGraph.printJobGraph(); + } +} diff --git 
a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/CrossLangSerializerTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/CrossLangSerializerTest.java new file mode 100644 index 00000000..a3584d2e --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/CrossLangSerializerTest.java @@ -0,0 +1,26 @@ +package io.ray.streaming.runtime.serialization; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import io.ray.streaming.message.KeyRecord; +import io.ray.streaming.message.Record; +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.testng.annotations.Test; + +public class CrossLangSerializerTest { + + @Test + @SuppressWarnings("unchecked") + public void testSerialize() { + CrossLangSerializer serializer = new CrossLangSerializer(); + Record record = new Record("value"); + record.setStream("stream1"); + assertTrue( + EqualsBuilder.reflectionEquals( + record, serializer.deserialize(serializer.serialize(record)))); + KeyRecord keyRecord = new KeyRecord("key", "value"); + keyRecord.setStream("stream2"); + assertEquals(keyRecord, serializer.deserialize(serializer.serialize(keyRecord))); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/MsgPackSerializerTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/MsgPackSerializerTest.java new file mode 100644 index 00000000..4a574e4d --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/MsgPackSerializerTest.java @@ -0,0 +1,49 @@ +package io.ray.streaming.runtime.serialization; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import 
java.util.Map; +import org.testng.annotations.Test; + +@SuppressWarnings("unchecked") +public class MsgPackSerializerTest { + + @Test + public void testSerializeByte() { + MsgPackSerializer serializer = new MsgPackSerializer(); + + assertEquals(serializer.deserialize(serializer.serialize((byte) 1)), (byte) 1); + } + + @Test + public void testSerialize() { + MsgPackSerializer serializer = new MsgPackSerializer(); + + assertEquals(serializer.deserialize(serializer.serialize(Short.MAX_VALUE)), Short.MAX_VALUE); + assertEquals( + serializer.deserialize(serializer.serialize(Integer.MAX_VALUE)), Integer.MAX_VALUE); + assertEquals(serializer.deserialize(serializer.serialize(Long.MAX_VALUE)), Long.MAX_VALUE); + + Map map = new HashMap(); + List list = new ArrayList<>(); + list.add(null); + list.add(true); + list.add(1.0d); + list.add("str"); + map.put("k1", "value1"); + map.put("k2", new HashMap<>()); + map.put("k3", list); + byte[] bytes = serializer.serialize(map); + Object o = serializer.deserialize(bytes); + assertEquals(o, map); + + byte[] binary = {1, 2, 3, 4}; + assertTrue( + Arrays.equals(binary, (byte[]) (serializer.deserialize(serializer.serialize(binary))))); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java new file mode 100644 index 00000000..3ff425d2 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java @@ -0,0 +1,235 @@ +package io.ray.streaming.runtime.streamingqueue; + +import com.google.common.collect.ImmutableMap; +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.function.impl.FlatMapFunction; +import io.ray.streaming.api.function.impl.ReduceFunction; +import 
io.ray.streaming.api.stream.DataStreamSource; +import io.ray.streaming.runtime.BaseUnitTest; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import io.ray.streaming.runtime.util.EnvUtil; +import io.ray.streaming.util.Config; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +public class StreamingQueueTest extends BaseUnitTest implements Serializable { + + private static Logger LOGGER = LoggerFactory.getLogger(StreamingQueueTest.class); + + static { + EnvUtil.loadNativeLibraries(); + } + + @org.testng.annotations.BeforeSuite + public void suiteSetUp() throws Exception { + LOGGER.info("Do set up"); + String management = ManagementFactory.getRuntimeMXBean().getName(); + String pid = management.split("@")[0]; + + LOGGER.info("StreamingQueueTest pid: {}", pid); + LOGGER.info("java.library.path = {}", System.getProperty("java.library.path")); + } + + @org.testng.annotations.AfterSuite + public void suiteTearDown() throws Exception { + LOGGER.warn("Do tear down"); + } + + @BeforeMethod + void beforeMethod() { + LOGGER.info("beforeTest"); + Ray.shutdown(); + System.setProperty("ray.head-args.0", "--num-cpus=4"); + System.setProperty("ray.head-args.1", "--resources={\"RES-A\":4}"); + System.setProperty("ray.run-mode", "CLUSTER"); + System.setProperty("ray.redirect-output", "true"); + Ray.init(); + } + + @AfterMethod + void afterMethod() { + LOGGER.info("afterTest"); + Ray.shutdown(); + 
System.clearProperty("ray.run-mode"); + System.clearProperty("ray.head-args.0"); + System.clearProperty("ray.head-args.1"); + } + + @Test(timeOut = 300000) + public void testReaderWriter() { + LOGGER.info( + "StreamingQueueTest.testReaderWriter run-mode: {}", System.getProperty("ray.run-mode")); + Ray.shutdown(); + System.setProperty("ray.head-args.0", "--num-cpus=4"); + System.setProperty("ray.head-args.1", "--resources={\"RES-A\":4}"); + + System.setProperty("ray.run-mode", "CLUSTER"); + System.setProperty("ray.redirect-output", "true"); + // ray init + Ray.init(); + + ActorHandle writerActor = Ray.actor(WriterWorker::new, "writer").remote(); + ActorHandle readerActor = Ray.actor(ReaderWorker::new, "reader").remote(); + + LOGGER.info( + "call getName on writerActor: {}", writerActor.task(WriterWorker::getName).remote().get()); + LOGGER.info( + "call getName on readerActor: {}", readerActor.task(ReaderWorker::getName).remote().get()); + + // LOGGER.info(writerActor.task(WriterWorker::testCallReader, readerActor).remote().get()); + List outputQueueList = new ArrayList<>(); + List inputQueueList = new ArrayList<>(); + int queueNum = 2; + for (int i = 0; i < queueNum; ++i) { + String qid = ChannelId.genRandomIdStr(); + LOGGER.info("getRandomQueueId: {}", qid); + inputQueueList.add(qid); + outputQueueList.add(qid); + readerActor.getId(); + } + + final int msgCount = 100; + readerActor.task(ReaderWorker::init, inputQueueList, writerActor, msgCount).remote(); + try { + Thread.sleep(1000); + } catch (Exception e) { + e.printStackTrace(); + } + writerActor.task(WriterWorker::init, outputQueueList, readerActor, msgCount).remote(); + + long time = 0; + while (time < 20000 + && readerActor.task(ReaderWorker::getTotalMsg).remote().get() < msgCount * queueNum) { + try { + Thread.sleep(1000); + time += 1000; + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + Assert.assertEquals( + readerActor.task(ReaderWorker::getTotalMsg).remote().get().intValue(), 
msgCount * queueNum); + } + + @Test(timeOut = 60000) + public void testWordCount() { + Ray.shutdown(); + System.setProperty("ray.head-args.0", "--num-cpus=4"); + System.setProperty("ray.head-args.1", "--resources={\"RES-A\":4}"); + + System.setProperty("ray.run-mode", "CLUSTER"); + System.setProperty("ray.redirect-output", "true"); + // ray init + Ray.init(); + LOGGER.info("testWordCount"); + LOGGER.info( + "StreamingQueueTest.testWordCount run-mode: {}", System.getProperty("ray.run-mode")); + String resultFile = "/tmp/io.ray.streaming.runtime.streamingqueue.testWordCount.txt"; + deleteResultFile(resultFile); + + Map wordCount = new ConcurrentHashMap<>(); + StreamingContext streamingContext = StreamingContext.buildContext(); + Map config = new HashMap<>(); + config.put(Config.CHANNEL_TYPE, "NATIVE_CHANNEL"); + config.put(Config.CHANNEL_SIZE, "100000"); + streamingContext.withConfig(config); + List text = new ArrayList<>(); + text.add("hello world eagle eagle eagle"); + DataStreamSource streamSource = DataStreamSource.fromCollection(streamingContext, text); + streamSource + .flatMap( + (FlatMapFunction) + (value, collector) -> { + String[] records = value.split(" "); + for (String record : records) { + collector.collect(new WordAndCount(record, 1)); + } + }) + .keyBy(pair -> pair.word) + .reduce( + (ReduceFunction) + (oldValue, newValue) -> { + LOGGER.info("reduce: {} {}", oldValue, newValue); + return new WordAndCount(oldValue.word, oldValue.count + newValue.count); + }) + .sink( + s -> { + LOGGER.info("sink {} {}", s.word, s.count); + wordCount.put(s.word, s.count); + serializeResultToFile(resultFile, wordCount); + }); + + streamingContext.execute("testSQWordCount"); + + Map checkWordCount = + (Map) deserializeResultFromFile(resultFile); + // Sleep until the count for every word is computed. 
+ while (checkWordCount == null || checkWordCount.size() < 3) { + LOGGER.info("sleep"); + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + LOGGER.warn("Got an exception while sleeping.", e); + } + checkWordCount = (Map) deserializeResultFromFile(resultFile); + } + LOGGER.info("check"); + Assert.assertEquals(checkWordCount, ImmutableMap.of("eagle", 3, "hello", 1, "world", 1)); + } + + private void serializeResultToFile(String fileName, Object obj) { + try { + ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(fileName)); + out.writeObject(obj); + } catch (Exception e) { + LOGGER.error(String.valueOf(e)); + } + } + + private Object deserializeResultFromFile(String fileName) { + Map checkWordCount = null; + try { + ObjectInputStream in = new ObjectInputStream(new FileInputStream(fileName)); + checkWordCount = (Map) in.readObject(); + Assert.assertEquals(checkWordCount, ImmutableMap.of("eagle", 3, "hello", 1, "world", 1)); + } catch (Exception e) { + LOGGER.error(String.valueOf(e)); + } + return checkWordCount; + } + + private static class WordAndCount implements Serializable { + + public final String word; + public final Integer count; + + public WordAndCount(String key, Integer count) { + this.word = key; + this.count = count; + } + } + + private void deleteResultFile(String path) { + File file = new File(path); + file.deleteOnExit(); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java new file mode 100644 index 00000000..59201e45 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java @@ -0,0 +1,284 @@ +package io.ray.streaming.runtime.streamingqueue; + +import io.ray.api.ActorHandle; +import io.ray.api.BaseActorHandle; +import io.ray.api.Ray; +import 
io.ray.runtime.functionmanager.JavaFunctionDescriptor; +import io.ray.streaming.runtime.config.StreamingWorkerConfig; +import io.ray.streaming.runtime.transfer.ChannelCreationParametersBuilder; +import io.ray.streaming.runtime.transfer.DataReader; +import io.ray.streaming.runtime.transfer.DataWriter; +import io.ray.streaming.runtime.transfer.TransferHandler; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import io.ray.streaming.runtime.transfer.message.DataMessage; +import io.ray.streaming.util.Config; +import java.lang.management.ManagementFactory; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; + +public class Worker { + + private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class); + + protected TransferHandler transferHandler = null; + + public Worker() { + transferHandler = new TransferHandler(); + } + + public void onReaderMessage(byte[] buffer) { + transferHandler.onReaderMessage(buffer); + } + + public byte[] onReaderMessageSync(byte[] buffer) { + return transferHandler.onReaderMessageSync(buffer); + } + + public void onWriterMessage(byte[] buffer) { + transferHandler.onWriterMessage(buffer); + } + + public byte[] onWriterMessageSync(byte[] buffer) { + return transferHandler.onWriterMessageSync(buffer); + } +} + +class ReaderWorker extends Worker { + + private static final Logger LOGGER = LoggerFactory.getLogger(ReaderWorker.class); + + private String name = null; + private List inputQueueList = null; + List inputActors = new ArrayList<>(); + private DataReader dataReader = null; + private long handler = 0; + private ActorHandle peerActor = null; + private int msgCount = 0; + private int totalMsg = 0; + + public ReaderWorker(String name) { + LOGGER.info("ReaderWorker constructor"); + this.name = name; + } + + public String getName() { 
+ String management = ManagementFactory.getRuntimeMXBean().getName(); + String pid = management.split("@")[0]; + + LOGGER.info("pid: {} name: {}", pid, name); + return name; + } + + public String testRayCall() { + LOGGER.info("testRayCall called"); + return "testRayCall"; + } + + public boolean init(List inputQueueList, ActorHandle peer, int msgCount) { + + this.inputQueueList = inputQueueList; + this.peerActor = peer; + this.msgCount = msgCount; + + LOGGER.info("ReaderWorker init"); + LOGGER.info("java.library.path = {}", System.getProperty("java.library.path")); + + for (String queue : this.inputQueueList) { + inputActors.add(this.peerActor); + LOGGER.info("ReaderWorker actorId: {}", this.peerActor.getId()); + } + + Map conf = new HashMap<>(); + + conf.put(Config.CHANNEL_TYPE, "NATIVE_CHANNEL"); + conf.put(Config.CHANNEL_SIZE, "100000"); + conf.put(Config.STREAMING_JOB_NAME, "integrationTest1"); + ChannelCreationParametersBuilder.setJavaWriterFunctionDesc( + new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessage", "([B)V"), + new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessageSync", "([B)[B")); + StreamingWorkerConfig workerConfig = new StreamingWorkerConfig(conf); + dataReader = new DataReader(inputQueueList, inputActors, new HashMap<>(), workerConfig); + + // Should not GetBundle in RayCall thread + Thread readThread = + new Thread( + Ray.wrapRunnable( + new Runnable() { + @Override + public void run() { + consume(); + } + })); + readThread.start(); + + LOGGER.info("ReaderWorker init done"); + + return true; + } + + public final void consume() { + + int checkPointId = 1; + for (int i = 0; i < msgCount * inputQueueList.size(); ++i) { + DataMessage dataMessage = (DataMessage) dataReader.read(100); + + if (dataMessage == null) { + LOGGER.error("dataMessage is null"); + i--; + continue; + } + + int bufferSize = dataMessage.body().remaining(); + int dataSize = dataMessage.body().getInt(); + + // check size + LOGGER.info( + "capacity {} 
bufferSize {} dataSize {}", + dataMessage.body().capacity(), + bufferSize, + dataSize); + Assert.assertEquals(bufferSize, dataSize); + if (dataMessage instanceof DataMessage) { + if (LOGGER.isInfoEnabled()) { + LOGGER.info("{} : {} message.", i, dataMessage.toString()); + } + // check content + for (int j = 0; j < dataSize - 4; ++j) { + Assert.assertEquals(dataMessage.body().get(), (byte) j); + } + } else { + LOGGER.error("unknown message type"); + Assert.fail(); + } + + totalMsg++; + } + + LOGGER.info("ReaderWorker consume data done."); + } + + void onQueueTransfer(long handler, byte[] buffer) {} + + public boolean done() { + return totalMsg == msgCount; + } + + public int getTotalMsg() { + return totalMsg; + } +} + +class WriterWorker extends Worker { + + private static final Logger LOGGER = LoggerFactory.getLogger(WriterWorker.class); + + private String name = null; + private List outputQueueList = null; + List outputActors = new ArrayList<>(); + DataWriter dataWriter = null; + ActorHandle peerActor = null; + int msgCount = 0; + + public WriterWorker(String name) { + this.name = name; + } + + public String getName() { + String management = ManagementFactory.getRuntimeMXBean().getName(); + String pid = management.split("@")[0]; + + LOGGER.info("pid: {} name: {}", pid, name); + return name; + } + + public String testCallReader(ActorHandle readerActor) { + String name = readerActor.task(ReaderWorker::getName).remote().get(); + LOGGER.info("testCallReader: {}", name); + return name; + } + + public boolean init(List outputQueueList, ActorHandle peer, int msgCount) { + + this.outputQueueList = outputQueueList; + this.peerActor = peer; + this.msgCount = msgCount; + + LOGGER.info("WriterWorker init:"); + + for (String queue : this.outputQueueList) { + outputActors.add(this.peerActor); + LOGGER.info("WriterWorker actorId: {}", this.peerActor.getId()); + } + + int count = 3; + while (count-- != 0) { + peer.task(ReaderWorker::testRayCall).remote().get(); + } + + try { + 
Thread.sleep(2 * 1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + Map conf = new HashMap<>(); + + conf.put(Config.CHANNEL_TYPE, "NATIVE_CHANNEL"); + conf.put(Config.CHANNEL_SIZE, "100000"); + conf.put(Config.STREAMING_JOB_NAME, "integrationTest1"); + ChannelCreationParametersBuilder.setJavaReaderFunctionDesc( + new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessage", "([B)V"), + new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessageSync", "([B)[B")); + StreamingWorkerConfig workerConfig = new StreamingWorkerConfig(conf); + dataWriter = new DataWriter(outputQueueList, outputActors, new HashMap<>(), workerConfig); + Thread writerThread = + new Thread( + Ray.wrapRunnable( + new Runnable() { + @Override + public void run() { + produce(); + } + })); + writerThread.start(); + + LOGGER.info("WriterWorker init done"); + return true; + } + + public final void produce() { + + int checkPointId = 1; + Random random = new Random(); + this.msgCount = 100; + for (int i = 0; i < this.msgCount; ++i) { + for (int j = 0; j < outputQueueList.size(); ++j) { + LOGGER.info("WriterWorker produce"); + int dataSize = (random.nextInt(100)) + 10; + if (LOGGER.isInfoEnabled()) { + LOGGER.info("dataSize: {}", dataSize); + } + ByteBuffer bb = ByteBuffer.allocate(dataSize); + bb.putInt(dataSize); + for (int k = 0; k < dataSize - 4; ++k) { + bb.put((byte) k); + } + + bb.clear(); + ChannelId qid = ChannelId.from(outputQueueList.get(j)); + dataWriter.write(qid, bb); + } + } + try { + Thread.sleep(20 * 1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/transfer/ChannelIdTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/transfer/ChannelIdTest.java new file mode 100644 index 00000000..0a91d453 --- /dev/null +++ 
b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/transfer/ChannelIdTest.java @@ -0,0 +1,22 @@ +package io.ray.streaming.runtime.transfer; + +import static org.testng.Assert.assertEquals; + +import io.ray.streaming.runtime.BaseUnitTest; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import io.ray.streaming.runtime.util.EnvUtil; +import org.testng.annotations.Test; + +public class ChannelIdTest extends BaseUnitTest { + + static { + EnvUtil.loadNativeLibraries(); + } + + @Test + public void testIdStrToBytes() { + String idStr = ChannelId.genRandomIdStr(); + assertEquals(idStr.length(), ChannelId.ID_LENGTH * 2); + assertEquals(ChannelId.idStrToBytes(idStr).length, ChannelId.ID_LENGTH); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/util/Mockitools.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/util/Mockitools.java new file mode 100644 index 00000000..eb48f169 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/util/Mockitools.java @@ -0,0 +1,73 @@ +package io.ray.streaming.runtime.util; + +import io.ray.api.id.UniqueId; +import io.ray.api.runtimecontext.NodeInfo; +import io.ray.streaming.runtime.core.resource.ResourceType; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.powermock.api.mockito.PowerMockito; + +/** Mockitools is a tool based on powermock and mokito to mock external service api */ +public class Mockitools { + + /** Mock GCS get node info api */ + public static void mockGscApi() { + PowerMockito.mockStatic(RayUtils.class); + PowerMockito.when(RayUtils.getAliveNodeInfoMap()) + .thenReturn(mockGetNodeInfoMap(mockGetAllNodeInfo())); + } + + /** Mock get all node info from GCS */ + public static List mockGetAllNodeInfo() { + List nodeInfos = new LinkedList<>(); + + for (int i = 1; i <= 5; i++) { + Map 
resources = new HashMap<>(); + resources.put("MEM", 16.0); + switch (i) { + case 1: + resources.put(ResourceType.CPU.getValue(), 3.0); + break; + case 2: + case 3: + case 4: + resources.put(ResourceType.CPU.getValue(), 4.0); + break; + case 5: + resources.put(ResourceType.CPU.getValue(), 2.0); + break; + } + + nodeInfos.add(mockNodeInfo(i, resources)); + } + return nodeInfos; + } + + /** + * Mock get node info map + * + * @param nodeInfos all node infos fetched from GCS + * @return node info map, key is node unique id, value is node info + */ + public static Map mockGetNodeInfoMap(List nodeInfos) { + return nodeInfos.stream() + .filter(nodeInfo -> nodeInfo.isAlive) + .collect(Collectors.toMap(nodeInfo -> nodeInfo.nodeId, nodeInfo -> nodeInfo)); + } + + private static NodeInfo mockNodeInfo(int i, Map resources) { + return new NodeInfo( + createNodeId(i), "localhost" + i, "localhost" + i, -1, "", "", true, resources); + } + + private static UniqueId createNodeId(int id) { + byte[] nodeIdBytes = new byte[UniqueId.LENGTH]; + for (int byteIndex = 0; byteIndex < UniqueId.LENGTH; ++byteIndex) { + nodeIdBytes[byteIndex] = String.valueOf(id).getBytes()[0]; + } + return new UniqueId(nodeIdBytes); + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/util/ReflectionUtilsTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/util/ReflectionUtilsTest.java new file mode 100644 index 00000000..570f067d --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/util/ReflectionUtilsTest.java @@ -0,0 +1,35 @@ +package io.ray.streaming.runtime.util; + +import static org.testng.Assert.assertEquals; + +import java.io.Serializable; +import java.util.Collections; +import org.testng.annotations.Test; + +public class ReflectionUtilsTest { + + static class Foo implements Serializable { + + public void f1() {} + + public void f2() {} + + public void f2(boolean a1) {} + } + + @Test + public 
void testFindMethod() throws NoSuchMethodException { + assertEquals(Foo.class.getDeclaredMethod("f1"), ReflectionUtils.findMethod(Foo.class, "f1")); + } + + @Test + public void testFindMethods() { + assertEquals(ReflectionUtils.findMethods(Foo.class, "f2").size(), 2); + } + + @Test + public void testGetAllInterfaces() { + assertEquals( + ReflectionUtils.getAllInterfaces(Foo.class), Collections.singletonList(Serializable.class)); + } +} diff --git a/streaming/java/streaming-runtime/src/test/resources/log4j.properties b/streaming/java/streaming-runtime/src/test/resources/log4j.properties new file mode 100644 index 00000000..8d40bd19 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/resources/log4j.properties @@ -0,0 +1,6 @@ +log4j.rootLogger=INFO, stdout +# Direct log messages to stdout +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.Target=System.out +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SS} %-4p %c{1}:%L [%t] - %m%n diff --git a/streaming/java/streaming-runtime/src/test/resources/ray.conf b/streaming/java/streaming-runtime/src/test/resources/ray.conf new file mode 100644 index 00000000..fdc897fa --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/resources/ray.conf @@ -0,0 +1,3 @@ +ray { + run-mode = SINGLE_PROCESS +} diff --git a/streaming/java/streaming-state/pom_template.xml b/streaming/java/streaming-state/pom_template.xml new file mode 100644 index 00000000..f6fc6c85 --- /dev/null +++ b/streaming/java/streaming-state/pom_template.xml @@ -0,0 +1,22 @@ + + {auto_gen_header} + + + ray-streaming + io.ray + 2.0.0-SNAPSHOT + + 4.0.0 + + streaming-state + ray streaming state + ray streaming state + jar + + + {generated_bzl_deps} + + + diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/KeyValueState.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/KeyValueState.java 
new file mode 100644 index 00000000..69a7931a --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/KeyValueState.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state; + +/** Key Value State interface. */ +public interface KeyValueState { + + /** get value from state */ + V get(K key); + + /** put key and value into state */ + void put(K k, V v); +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/PartitionRecord.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/PartitionRecord.java new file mode 100644 index 00000000..50fd3fe4 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/PartitionRecord.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state; + +import java.io.Serializable; + +/** value record for partition. */ +public class PartitionRecord implements Serializable { + + /** The partition number of the partitioned value. */ + private int partitionID; + + private T value; + + public PartitionRecord() {} + + public PartitionRecord(int partitionID, T value) { + this.partitionID = partitionID; + this.value = value; + } + + public T getValue() { + return value; + } + + public int getPartitionID() { + return partitionID; + } + + public void setPartitionID(int partitionID) { + this.partitionID = partitionID; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/StateException.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/StateException.java new file mode 100644 index 00000000..f3eb5b2c --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/StateException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state; + +/** RuntimeException wrapper, indicating the exceptions occurs in states. */ +public class StateException extends RuntimeException { + + public StateException(Throwable t) { + super(t); + } + + public StateException(String msg) { + super(msg); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/StateStoreManager.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/StateStoreManager.java new file mode 100644 index 00000000..b1f3c547 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/StateStoreManager.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state; + +/** + * TransactionState interface. + * + *

Streaming State should implement transaction in case of failure, which in our case is four + * default method, finish, commit, ackCommit, rollback. + */ +public interface StateStoreManager { + + /** + * The finish method is used when the batched data is all saved in state. Normally, serialization + * job is done here. + */ + void finish(long checkpointId); + + /** + * The commit method is used for persistent, and can be used in another thread to reach async + * state commit. Normally, data persistent is done here. + */ + void commit(long checkpointId); + + /** + * The ackCommit method is used for cleaning the last checkpoint, and must be called after commit + * in the same thread. + */ + void ackCommit(long checkpointId, long timeStamp); + + /** The rollback method is used for recovering the checkpoint. */ + void rollBack(long checkpointId); +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/StorageRecord.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/StorageRecord.java new file mode 100644 index 00000000..dcce1d01 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/StorageRecord.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state; + +import java.io.Serializable; + +/** This Class contains a record with some checkpointId. */ +public class StorageRecord implements Serializable { + + private long checkpointId; + private T value; + + public StorageRecord() {} + + public StorageRecord(long checkpointId, T value) { + this.checkpointId = checkpointId; + this.value = value; + } + + public T getValue() { + return value; + } + + public long getCheckpointId() { + return checkpointId; + } + + public void setCheckpointId(long checkpointId) { + this.checkpointId = checkpointId; + } + + @Override + public String toString() { + if (value != null) { + return "checkpointId:" + checkpointId + ", value:" + value; + } else { + return "checkpointId:" + checkpointId + ", value:null"; + } + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/AbstractKeyStateBackend.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/AbstractKeyStateBackend.java new file mode 100644 index 00000000..43c4c781 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/AbstractKeyStateBackend.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.backend; + +import io.ray.streaming.state.StateStoreManager; +import io.ray.streaming.state.keystate.desc.AbstractStateDescriptor; +import io.ray.streaming.state.keystate.desc.AbstractStateDescriptor.StateType; +import io.ray.streaming.state.keystate.state.proxy.ListStateStoreManagerProxy; +import io.ray.streaming.state.keystate.state.proxy.MapStateStoreManagerProxy; +import io.ray.streaming.state.keystate.state.proxy.ValueStateStoreManagerProxy; +import io.ray.streaming.state.store.KeyMapStore; +import io.ray.streaming.state.store.KeyValueStore; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Transaction support primitive operations like finish, commit, ackcommit and rollback. + * + *

State value modification is not thread safe! By default, every processing thread has its own + * space to handle state. + */ +public abstract class AbstractKeyStateBackend implements StateStoreManager { + + private static final Logger LOG = LoggerFactory.getLogger(AbstractKeyStateBackend.class); + + protected long currentCheckpointId; + protected Object currentKey; + protected int keyGroupIndex = -1; + protected Map valueManagerProxyHashMap = new HashMap<>(); + protected Map listManagerProxyHashMap = new HashMap<>(); + protected Map mapManagerProxyHashMap = new HashMap<>(); + protected Set descNamespace; + + /** tablename, KeyValueStore key, checkpointId, content */ + protected Map>> backStorageCache; + + private AbstractStateBackend backend; + + public AbstractKeyStateBackend(AbstractStateBackend backend) { + this.backStorageCache = new HashMap<>(); + this.backend = backend; + this.descNamespace = new HashSet<>(); + } + + public void put(AbstractStateDescriptor descriptor, K key, T value) { + String desc = descriptor.getIdentify(); + if (descriptor.getStateType() == StateType.VALUE) { + if (this.valueManagerProxyHashMap.containsKey(desc)) { + valueManagerProxyHashMap.get(desc).put((String) key, value); + } + } else if (descriptor.getStateType() == StateType.LIST) { + if (this.listManagerProxyHashMap.containsKey(desc)) { + listManagerProxyHashMap.get(desc).put((String) key, value); + } + } else if (descriptor.getStateType() == StateType.MAP) { + if (this.mapManagerProxyHashMap.containsKey(desc)) { + mapManagerProxyHashMap.get(desc).put((String) key, value); + } + } + } + + public T get(AbstractStateDescriptor descriptor, K key) { + String desc = descriptor.getIdentify(); + if (descriptor.getStateType() == StateType.VALUE) { + if (this.valueManagerProxyHashMap.containsKey(desc)) { + return (T) valueManagerProxyHashMap.get(desc).get((String) key); + } + } else if (descriptor.getStateType() == StateType.LIST) { + if (this.listManagerProxyHashMap.containsKey(desc)) { + 
return (T) listManagerProxyHashMap.get(desc).get((String) key); + } + } else if (descriptor.getStateType() == StateType.MAP) { + if (this.mapManagerProxyHashMap.containsKey(desc)) { + return (T) mapManagerProxyHashMap.get(desc).get((String) key); + } + } + return null; + } + + @Override + public void finish(long checkpointId) { + for (Entry entry : valueManagerProxyHashMap.entrySet()) { + entry.getValue().finish(checkpointId); + } + for (Entry entry : listManagerProxyHashMap.entrySet()) { + entry.getValue().finish(checkpointId); + } + for (Entry entry : mapManagerProxyHashMap.entrySet()) { + entry.getValue().finish(checkpointId); + } + } + + @Override + public void commit(long checkpointId) { + for (Entry entry : valueManagerProxyHashMap.entrySet()) { + entry.getValue().commit(checkpointId); + } + for (Entry entry : listManagerProxyHashMap.entrySet()) { + entry.getValue().commit(checkpointId); + } + for (Entry entry : mapManagerProxyHashMap.entrySet()) { + entry.getValue().commit(checkpointId); + } + } + + @Override + public void ackCommit(long checkpointId, long timeStamp) { + for (Entry entry : valueManagerProxyHashMap.entrySet()) { + entry.getValue().ackCommit(checkpointId, timeStamp); + } + for (Entry entry : listManagerProxyHashMap.entrySet()) { + entry.getValue().ackCommit(checkpointId, timeStamp); + } + for (Entry entry : mapManagerProxyHashMap.entrySet()) { + entry.getValue().ackCommit(checkpointId, timeStamp); + } + } + + @Override + public void rollBack(long checkpointId) { + for (Entry entry : valueManagerProxyHashMap.entrySet()) { + LOG.warn("backend rollback:{},{}", entry.getKey(), checkpointId); + entry.getValue().rollBack(checkpointId); + } + for (Entry entry : listManagerProxyHashMap.entrySet()) { + LOG.warn("backend rollback:{},{}", entry.getKey(), checkpointId); + entry.getValue().rollBack(checkpointId); + } + for (Entry entry : mapManagerProxyHashMap.entrySet()) { + LOG.warn("backend rollback:{},{}", entry.getKey(), checkpointId); + 
entry.getValue().rollBack(checkpointId); + } + } + + public KeyValueStore> getBackStorage(String tableName) { + if (this.backStorageCache.containsKey(tableName)) { + return this.backStorageCache.get(tableName); + } else { + KeyMapStore ikvStore = this.backend.getKeyMapStore(tableName); + this.backStorageCache.put(tableName, ikvStore); + return ikvStore; + } + } + + public KeyValueStore> getBackStorage( + AbstractStateDescriptor stateDescriptor) { + String tableName = this.backend.getTableName(stateDescriptor); + return getBackStorage(tableName); + } + + public StateStrategy getStateStrategy() { + return this.backend.getStateStrategy(); + } + + public BackendType getBackendType() { + return this.backend.getBackendType(); + } + + public Object getCurrentKey() { + return this.currentKey; + } + + public abstract void setCurrentKey(Object currentKey); + + public long getCheckpointId() { + return this.currentCheckpointId; + } + + public void setCheckpointId(long checkpointId) { + this.currentCheckpointId = checkpointId; + } + + public void setContext(long checkpointId, Object currentKey) { + setCheckpointId(checkpointId); + setCurrentKey(currentKey); + } + + public AbstractStateBackend getBackend() { + return backend; + } + + public int getKeyGroupIndex() { + return this.keyGroupIndex; + } + + public void setKeyGroupIndex(int keyGroupIndex) { + this.keyGroupIndex = keyGroupIndex; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/AbstractStateBackend.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/AbstractStateBackend.java new file mode 100644 index 00000000..1f040828 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/AbstractStateBackend.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.backend; + +import static io.ray.streaming.state.config.ConfigKey.DELIMITER; + +import io.ray.streaming.state.config.ConfigKey; +import io.ray.streaming.state.keystate.desc.AbstractStateDescriptor; +import io.ray.streaming.state.serialization.KeyMapStoreSerializer; +import io.ray.streaming.state.store.KeyMapStore; +import io.ray.streaming.state.store.KeyValueStore; +import java.io.Serializable; +import java.util.Map; + +/** This class is the abstract class for different kinds of StateBackend. 
*/ +public abstract class AbstractStateBackend implements Serializable { + + protected final Map config; + protected final StateStrategy stateStrategy; + + protected final BackendType backendType; + protected int keyGroupIndex = -1; + + protected AbstractStateBackend(Map config) { + this.stateStrategy = StateStrategy.getEnum(ConfigKey.getStateStrategyEnum(config)); + this.backendType = BackendType.getEnum(ConfigKey.getBackendType(config)); + this.config = config; + } + + public abstract KeyValueStore getKeyValueStore(String tableName); + + public abstract KeyMapStore getKeyMapStore(String tableName); + + public abstract KeyMapStore getKeyMapStore( + String tableName, KeyMapStoreSerializer keyMapStoreSerializer); + + public BackendType getBackendType() { + return backendType; + } + + public StateStrategy getStateStrategy() { + return stateStrategy; + } + + public String getTableName(AbstractStateDescriptor stateDescriptor) { + return stateDescriptor.getTableName(); + } + + public String getStateKey(String descName, String currentKey) { + return descName + DELIMITER + currentKey; + } + + public void setKeyGroupIndex(int keyGroupIndex) { + this.keyGroupIndex = keyGroupIndex; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/BackendType.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/BackendType.java new file mode 100644 index 00000000..7c63d3e6 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/BackendType.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.backend; + +/** Backend Types. */ +public enum BackendType { + /** Saving the state values in memory */ + MEMORY; + + /** get the enum from input string value, ignoring the case */ + public static BackendType getEnum(String value) { + for (BackendType v : values()) { + if (v.name().equalsIgnoreCase(value)) { + return v; + } + } + return MEMORY; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/KeyStateBackend.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/KeyStateBackend.java new file mode 100644 index 00000000..bbf7619e --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/KeyStateBackend.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.backend; + +import io.ray.streaming.state.keystate.KeyGroup; +import io.ray.streaming.state.keystate.KeyGroupAssignment; +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import io.ray.streaming.state.keystate.desc.MapStateDescriptor; +import io.ray.streaming.state.keystate.desc.ValueStateDescriptor; +import io.ray.streaming.state.keystate.state.ListState; +import io.ray.streaming.state.keystate.state.MapState; +import io.ray.streaming.state.keystate.state.ValueState; +import io.ray.streaming.state.keystate.state.proxy.ListStateStoreManagerProxy; +import io.ray.streaming.state.keystate.state.proxy.MapStateStoreManagerProxy; +import io.ray.streaming.state.keystate.state.proxy.ValueStateStoreManagerProxy; + +/** + * key state backend manager, managing different kinds of states in different thread. This class is + * not thread safe. 
+ */ +public class KeyStateBackend extends AbstractKeyStateBackend { + + protected final int numberOfKeyGroups; + protected final KeyGroup keyGroup; + + public KeyStateBackend( + int numberOfKeyGroups, KeyGroup keyGroup, AbstractStateBackend abstractStateBackend) { + super(abstractStateBackend); + this.numberOfKeyGroups = numberOfKeyGroups; + this.keyGroup = keyGroup; + } + + /** get new value state proxy */ + protected ValueStateStoreManagerProxy newValueStateProxy( + ValueStateDescriptor stateDescriptor) { + return new ValueStateStoreManagerProxy<>(this, stateDescriptor); + } + + public ValueState getValueState(ValueStateDescriptor stateDescriptor) { + String desc = stateDescriptor.getIdentify(); + if (valueManagerProxyHashMap.containsKey(desc)) { + return valueManagerProxyHashMap.get(desc).getValueState(); + } else { + ValueStateStoreManagerProxy valueStateProxy = newValueStateProxy(stateDescriptor); + valueManagerProxyHashMap.put(desc, valueStateProxy); + return valueStateProxy.getValueState(); + } + } + + /** get new list state proxy */ + protected ListStateStoreManagerProxy newListStateProxy( + ListStateDescriptor stateDescriptor) { + return new ListStateStoreManagerProxy<>(this, stateDescriptor); + } + + public ListState getListState(ListStateDescriptor stateDescriptor) { + String desc = stateDescriptor.getIdentify(); + if (listManagerProxyHashMap.containsKey(desc)) { + ListStateStoreManagerProxy listStateProxy = listManagerProxyHashMap.get(desc); + return listStateProxy.getListState(); + } else { + ListStateStoreManagerProxy listStateProxy = newListStateProxy(stateDescriptor); + listManagerProxyHashMap.put(desc, listStateProxy); + return listStateProxy.getListState(); + } + } + + /** get map state proxy */ + protected MapStateStoreManagerProxy newMapStateProxy( + MapStateDescriptor stateDescriptor) { + return new MapStateStoreManagerProxy<>(this, stateDescriptor); + } + + public MapState getMapState(MapStateDescriptor stateDescriptor) { + String desc = 
stateDescriptor.getIdentify(); + if (mapManagerProxyHashMap.containsKey(desc)) { + MapStateStoreManagerProxy mapStateProxy = mapManagerProxyHashMap.get(desc); + return mapStateProxy.getMapState(); + } else { + MapStateStoreManagerProxy mapStateProxy = newMapStateProxy(stateDescriptor); + mapManagerProxyHashMap.put(desc, mapStateProxy); + return mapStateProxy.getMapState(); + } + } + + @Override + public void setCurrentKey(Object currentKey) { + super.keyGroupIndex = + KeyGroupAssignment.assignKeyGroupIndexForKey(currentKey, numberOfKeyGroups); + super.currentKey = currentKey; + } + + public int getNumberOfKeyGroups() { + return numberOfKeyGroups; + } + + public KeyGroup getKeyGroup() { + return keyGroup; + } + + public void close() { + for (ValueStateStoreManagerProxy proxy : valueManagerProxyHashMap.values()) { + proxy.close(); + } + for (ListStateStoreManagerProxy proxy : listManagerProxyHashMap.values()) { + proxy.close(); + } + for (MapStateStoreManagerProxy proxy : mapManagerProxyHashMap.values()) { + proxy.close(); + } + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/OperatorStateBackend.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/OperatorStateBackend.java new file mode 100644 index 00000000..c41ad788 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/OperatorStateBackend.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.backend; + +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import io.ray.streaming.state.keystate.state.ListState; +import io.ray.streaming.state.keystate.state.impl.OperatorStateImpl; +import io.ray.streaming.state.keystate.state.proxy.ListStateStoreManagerProxy; + +/** OperatorState manager for getting split or union list state. */ +public class OperatorStateBackend extends AbstractKeyStateBackend { + + public OperatorStateBackend(AbstractStateBackend backend) { + super(backend); + } + + @Override + public void setCurrentKey(Object currentKey) { + super.currentKey = currentKey; + } + + protected ListStateStoreManagerProxy newListStateStoreManagerProxy( + ListStateDescriptor stateDescriptor) { + return new ListStateStoreManagerProxy<>(this, stateDescriptor); + } + + /** get spitted List for different operator instance. */ + public ListState getSplitListState(ListStateDescriptor stateDescriptor) { + String desc = stateDescriptor.getIdentify(); + if (listManagerProxyHashMap.containsKey(desc)) { + ListStateStoreManagerProxy listStateProxy = listManagerProxyHashMap.get(desc); + return listStateProxy.getListState(); + } else { + ListStateStoreManagerProxy listStateProxy = newListStateStoreManagerProxy(stateDescriptor); + listManagerProxyHashMap.put(desc, listStateProxy); + ((OperatorStateImpl) (listStateProxy.getListState())).setSplit(true); + return listStateProxy.getListState(); + } + } + + /** get a union List for different operator instance. 
*/ + public ListState getUnionListState(ListStateDescriptor stateDescriptor) { + String desc = stateDescriptor.getIdentify(); + if (listManagerProxyHashMap.containsKey(desc)) { + ListStateStoreManagerProxy listStateProxy = listManagerProxyHashMap.get(desc); + return listStateProxy.getListState(); + } else { + ListStateStoreManagerProxy listStateProxy = newListStateStoreManagerProxy(stateDescriptor); + listManagerProxyHashMap.put(desc, listStateProxy); + ((OperatorStateImpl) (listStateProxy.getListState())).init(); + return listStateProxy.getListState(); + } + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/StateBackendBuilder.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/StateBackendBuilder.java new file mode 100644 index 00000000..91c3e2be --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/StateBackendBuilder.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.backend; + +import io.ray.streaming.state.backend.impl.MemoryStateBackend; +import io.ray.streaming.state.config.ConfigKey; +import java.util.Map; + +/** State Backend Builder. 
*/ +public class StateBackendBuilder { + + private static AbstractStateBackend getStateBackend( + Map config, BackendType type) { + switch (type) { + case MEMORY: + return new MemoryStateBackend(config); + default: + throw new RuntimeException(type.name() + " not supported"); + } + } + + public static AbstractStateBackend buildStateBackend(Map config) { + BackendType type; + if (config == null) { + type = BackendType.MEMORY; + } else { + type = BackendType.getEnum(config.get(ConfigKey.STATE_BACKEND_TYPE)); + } + + return getStateBackend(config, type); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/StateStrategy.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/StateStrategy.java new file mode 100644 index 00000000..cf81bac3 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/StateStrategy.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.backend; + +/** This class describe State Saving Model. */ +public enum StateStrategy { + /** save two version together in case of rollback. 
*/ + DUAL_VERSION, + + /** for storage supporting mvcc, we save only current version. */ + SINGLE_VERSION; + + public static StateStrategy getEnum(String value) { + for (StateStrategy v : values()) { + if (v.name().equalsIgnoreCase(value)) { + return v; + } + } + throw new IllegalArgumentException(value + " strategy is not supported"); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/impl/MemoryStateBackend.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/impl/MemoryStateBackend.java new file mode 100644 index 00000000..241f00dc --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/backend/impl/MemoryStateBackend.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.backend.impl; + +import io.ray.streaming.state.backend.AbstractStateBackend; +import io.ray.streaming.state.serialization.KeyMapStoreSerializer; +import io.ray.streaming.state.store.KeyMapStore; +import io.ray.streaming.state.store.KeyValueStore; +import io.ray.streaming.state.store.impl.MemoryKeyMapStore; +import io.ray.streaming.state.store.impl.MemoryKeyValueStore; +import java.util.Map; + +/** MemoryStateBackend. Supporting memory store. */ +public class MemoryStateBackend extends AbstractStateBackend { + + public MemoryStateBackend(Map config) { + super(config); + } + + @Override + public KeyValueStore getKeyValueStore(String tableName) { + return new MemoryKeyValueStore<>(); + } + + @Override + public KeyMapStore getKeyMapStore(String tableName) { + return new MemoryKeyMapStore<>(); + } + + @Override + public KeyMapStore getKeyMapStore( + String tableName, KeyMapStoreSerializer keyMapStoreSerializer) { + return new MemoryKeyMapStore<>(); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/config/ConfigHelper.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/config/ConfigHelper.java new file mode 100644 index 00000000..92a80da5 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/config/ConfigHelper.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.config; + +import java.util.Map; + +/** Config Helper figure out the config info. Todo replace this to config module. */ +public class ConfigHelper { + + public static int getIntegerOrDefault(Map config, String configKey, int defaultValue) { + if (config.containsKey(configKey)) { + return Integer.valueOf(String.valueOf(config.get(configKey))); + } else { + return defaultValue; + } + } + + public static String getStringOrDefault(Map config, String configKey, String defaultValue) { + if (config.containsKey(configKey)) { + return String.valueOf(config.get(configKey)); + } else { + return defaultValue; + } + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/config/ConfigKey.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/config/ConfigKey.java new file mode 100644 index 00000000..52704f02 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/config/ConfigKey.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.config; + +import java.util.Map; + +/** state config keys. Todo replace this to config module. */ +public final class ConfigKey { + + /** backend */ + public static final String STATE_BACKEND_TYPE = "state.backend.type"; + + public static final String STATE_TABLE_NAME = "state.table.name"; + public static final String STATE_STRATEGY_MODE = "state.strategy.mode"; + public static final String NUMBER_PER_CHECKPOINT = "number.per.checkpoint"; + public static final String JOB_MAX_PARALLEL = "job.max.parallel"; + public static final String DELIMITER = "\u0001\u0008"; // for String delimiter + + private ConfigKey() { + throw new AssertionError(); + } + + public static String getStateStrategyEnum(Map config) { + return ConfigHelper.getStringOrDefault(config, STATE_STRATEGY_MODE, "DUAL_VERSION"); + } + + public static String getBackendType(Map config) { + return ConfigHelper.getStringOrDefault(config, STATE_BACKEND_TYPE, "MEMORY"); + } + + public static int getNumberPerCheckpoint(Map config) { + return ConfigHelper.getIntegerOrDefault(config, NUMBER_PER_CHECKPOINT, 5); + } + + public static String getStateTableName(Map config) { + return ConfigHelper.getStringOrDefault(config, STATE_TABLE_NAME, "table"); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/KeyGroup.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/KeyGroup.java new file mode 100644 index 00000000..d5132365 --- /dev/null +++ 
b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/KeyGroup.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate; + +import com.google.common.base.Preconditions; +import java.io.Serializable; + +/** + * This class defines key-groups. Key-groups is the key space in a job, which is partitioned for + * keyed state processing in state backend. The boundaries of the key-group are inclusive. 
+ */ +public class KeyGroup implements Serializable { + + private final int startIndex; + private final int endIndex; + + /** + * Defines the range [startIndex, endIndex] + * + * @param startIndex start of the range (inclusive) + * @param endIndex end of the range (inclusive) + */ + public KeyGroup(int startIndex, int endIndex) { + Preconditions.checkArgument(startIndex >= 0 && startIndex <= endIndex); + this.startIndex = startIndex; + this.endIndex = endIndex; + Preconditions.checkArgument(size() >= 0, "overflow detected."); + } + + /** Returns The number of key-group in the range */ + public int size() { + return 1 + endIndex - startIndex; + } + + public int getStartIndex() { + return startIndex; + } + + public int getEndIndex() { + return endIndex; + } + + @Override + public String toString() { + return "KeyGroup{" + "startIndex=" + startIndex + ", endIndex=" + endIndex + '}'; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/KeyGroupAssignment.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/KeyGroupAssignment.java new file mode 100644 index 00000000..921ea859 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/KeyGroupAssignment.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** This class defines key-group assignment algorithm。 */ +public final class KeyGroupAssignment { + + /** + * Computes the range of key-groups that are assigned for a given operator instance. + * + * @param maxParallelism Maximal parallelism of the job. + * @param parallelism Parallelism for the job. <= maxParallelism. + * @param index index of the operator instance. + */ + public static KeyGroup getKeyGroup(int maxParallelism, int parallelism, int index) { + Preconditions.checkArgument( + maxParallelism >= parallelism, + "Maximum parallelism (%s) must not be smaller than parallelism(%s)", + maxParallelism, + parallelism); + + int start = index == 0 ? 0 : ((index * maxParallelism - 1) / parallelism) + 1; + int end = ((index + 1) * maxParallelism - 1) / parallelism; + return new KeyGroup(start, end); + } + + /** + * Assigning the key to a key-group index. + * + * @param key the key to assign. + * @param maxParallelism the maximum parallelism. + * @return the key-group index to which the given key is assigned. 
+ */ + public static int assignKeyGroupIndexForKey(Object key, int maxParallelism) { + return Math.abs(key.hashCode() % maxParallelism); + } + + public static Map> computeKeyGroupToTask( + int maxParallelism, List targetTasks) { + Map> keyGroupToTask = new ConcurrentHashMap<>(); + for (int index = 0; index < targetTasks.size(); index++) { + KeyGroup taskKeyGroup = getKeyGroup(maxParallelism, targetTasks.size(), index); + for (int groupId = taskKeyGroup.getStartIndex(); + groupId <= taskKeyGroup.getEndIndex(); + groupId++) { + keyGroupToTask.put(groupId, ImmutableList.of(targetTasks.get(index))); + } + } + return keyGroupToTask; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/AbstractStateDescriptor.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/AbstractStateDescriptor.java new file mode 100644 index 00000000..950f922c --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/AbstractStateDescriptor.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.keystate.desc; + +import com.google.common.base.Preconditions; +import io.ray.streaming.state.keystate.state.State; + +/** This class defines basic data structures of StateDescriptor. */ +public abstract class AbstractStateDescriptor { + + private final String name; + private String tableName; + private Class type; + + protected AbstractStateDescriptor(String name, Class type) { + this.name = name; + this.type = type; + } + + public String getName() { + return name; + } + + public Class getType() { + return type; + } + + protected Class setType(Class type) { + return type; + } + + public String getTableName() { + return tableName; + } + + public void setTableName(String tableName) { + this.tableName = tableName; + } + + public abstract StateType getStateType(); + + public String getIdentify() { + Preconditions.checkArgument(this.tableName != null, "table name must not be null."); + Preconditions.checkArgument(this.name != null, "table name must not be null."); + return this.name; + } + + @Override + public String toString() { + return "AbstractStateDescriptor{" + + "tableName='" + + tableName + + '\'' + + ", name='" + + name + + '\'' + + ", type=" + + type + + '}'; + } + + public enum StateType { + /** value state */ + VALUE, + + /** list state */ + LIST, + + /** map state */ + MAP + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/ListStateDescriptor.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/ListStateDescriptor.java new file mode 100644 index 00000000..62c7a249 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/ListStateDescriptor.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.desc; + +import static io.ray.streaming.state.config.ConfigKey.DELIMITER; + +import io.ray.streaming.state.keystate.state.ListState; + +/** ListStateDescriptor. */ +public class ListStateDescriptor extends AbstractStateDescriptor, T> { + + private final boolean isOperatorList; + private int index; + private int partitionNum; + + private ListStateDescriptor(String name, Class type, boolean isOperatorList) { + super(name, type); + this.isOperatorList = isOperatorList; + } + + public static ListStateDescriptor build(String name, Class type) { + return build(name, type, false); + } + + public static ListStateDescriptor build( + String name, Class type, boolean isOperatorList) { + return new ListStateDescriptor<>(name, type, isOperatorList); + } + + public boolean isOperatorList() { + return isOperatorList; + } + + public int getIndex() { + return index; + } + + public void setIndex(int index) { + this.index = index; + } + + public int getPartitionNumber() { + return partitionNum; + } + + public void setPartitionNumber(int number) { + this.partitionNum = number; + } + + @Override + public StateType getStateType() { + return StateType.LIST; + } + + @Override + public String getIdentify() { + if (isOperatorList) { + return String.format( + "%s%s%d%s%d", super.getIdentify(), DELIMITER, partitionNum, DELIMITER, index); + } else { + return 
super.getIdentify(); + } + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/MapStateDescriptor.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/MapStateDescriptor.java new file mode 100644 index 00000000..b6c3c382 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/MapStateDescriptor.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.desc; + +import io.ray.streaming.state.keystate.state.MapState; +import java.util.Map; + +/** MapStateDescriptor. 
*/ +public class MapStateDescriptor extends AbstractStateDescriptor, Map> { + + public MapStateDescriptor(String name, Class keyType, Class valueType) { + super(name, null); + // TODO: use the types to help serde + } + + public static MapStateDescriptor build( + String name, Class keyType, Class valueType) { + return new MapStateDescriptor<>(name, keyType, valueType); + } + + @Override + public StateType getStateType() { + return StateType.MAP; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/ValueStateDescriptor.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/ValueStateDescriptor.java new file mode 100644 index 00000000..79c7df47 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/desc/ValueStateDescriptor.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.desc; + +import io.ray.streaming.state.keystate.state.ValueState; + +/** ValueStateDescriptor. 
*/ +public class ValueStateDescriptor extends AbstractStateDescriptor, T> { + + private final T defaultValue; + + public ValueStateDescriptor(String name, Class type, T defaultValue) { + super(name, type); + this.defaultValue = defaultValue; + } + + public static ValueStateDescriptor build(String name, Class type, T defaultValue) { + return new ValueStateDescriptor<>(name, type, defaultValue); + } + + public T getDefaultValue() { + return defaultValue; + } + + @Override + public StateType getStateType() { + return StateType.VALUE; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/ListState.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/ListState.java new file mode 100644 index 00000000..966ee7ab --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/ListState.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state; + +import java.util.List; + +/** ListState interface. 
*/ +public interface ListState extends UnaryState> { + + /** + * add the value to list + * + * @param value the new value + */ + void add(T value); + + /** + * update list state + * + * @param list the new value + */ + void update(List list); +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/MapState.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/MapState.java new file mode 100644 index 00000000..a632d21d --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/MapState.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state; + +import java.util.Iterator; +import java.util.Map; +import java.util.Map.Entry; + +/** MapState interface. */ +public interface MapState extends UnaryState> { + + /** + * Returns the current value associated with the given key. + * + * @param key The key of the mapping + * @return The value of the mapping with the given key + */ + V get(K key); + + /** + * Associates a new value with the given key. 
+ * + * @param key The key of the mapping + * @param value The new value of the mapping + */ + void put(K key, V value); + + /** + * Resets the state value. + * + * @param map The mappings for reset in this state + */ + void update(Map map); + + /** + * Copies all of the mappings from the given map into the state. + * + * @param map The mappings to be stored in this state + */ + void putAll(Map map); + + /** + * Deletes the mapping of the given key. + * + * @param key The key of the mapping + */ + void remove(K key); + + /** + * Returns whether there exists the given mapping. + * + * @param key The key of the mapping + * @return True if there exists a mapping whose key equals to the given key + */ + default boolean contains(K key) { + return get().containsKey(key); + } + + /** + * Returns all the mappings in the state + * + * @return An iterable view of all the key-value pairs in the state. + */ + default Iterable> entries() { + return get().entrySet(); + } + + /** + * Returns all the keys in the state + * + * @return An iterable view of all the keys in the state. + */ + default Iterable keys() { + return get().keySet(); + } + + /** + * Returns all the values in the state. + * + * @return An iterable view of all the values in the state. + */ + default Iterable values() { + return get().values(); + } + + /** + * Iterates over all the mappings in the state. 
+ * + * @return An iterator over all the mappings in the state + */ + default Iterator> iterator() { + return get().entrySet().iterator(); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/State.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/State.java new file mode 100644 index 00000000..3e870469 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/State.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state; + +/** State interface. 
*/ +public interface State { + + /** set current key of the state */ + void setCurrentKey(Object currentKey); +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/UnaryState.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/UnaryState.java new file mode 100644 index 00000000..637b5731 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/UnaryState.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state; + +/** one value per state. 
*/ +public interface UnaryState extends State { + + /** + * get the value in state + * + * @return the value in state + */ + O get(); +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/ValueState.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/ValueState.java new file mode 100644 index 00000000..3c1b0637 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/ValueState.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state; + +/** ValueState interface. 
*/ +public interface ValueState extends UnaryState { + + /** + * update the value + * + * @param value the new value + */ + void update(T value); +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/ListStateImpl.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/ListStateImpl.java new file mode 100644 index 00000000..7dce9642 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/ListStateImpl.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state.impl; + +import io.ray.streaming.state.backend.AbstractKeyStateBackend; +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import io.ray.streaming.state.keystate.state.ListState; +import java.util.ArrayList; +import java.util.List; + +/** ListState implementation. 
*/ +public class ListStateImpl implements ListState { + + private final StateHelper> helper; + + public ListStateImpl(ListStateDescriptor descriptor, AbstractKeyStateBackend backend) { + this.helper = new StateHelper<>(backend, descriptor); + } + + @Override + public List get() { + List list = helper.get(); + if (list == null) { + list = new ArrayList<>(); + } + return list; + } + + @Override + public void add(V value) { + List list = helper.get(); + if (list == null) { + list = new ArrayList<>(); + } + list.add(value); + helper.put(list); + } + + @Override + public void update(List list) { + if (list == null) { + list = new ArrayList<>(); + } + helper.put(list); + } + + @Override + public void setCurrentKey(Object currentKey) { + helper.setCurrentKey(currentKey); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/MapStateImpl.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/MapStateImpl.java new file mode 100644 index 00000000..d31977dc --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/MapStateImpl.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state.impl; + +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.keystate.desc.MapStateDescriptor; +import io.ray.streaming.state.keystate.state.MapState; +import java.util.HashMap; +import java.util.Map; + +/** MapState implementation. */ +public class MapStateImpl implements MapState { + + private final StateHelper> helper; + + public MapStateImpl(MapStateDescriptor descriptor, KeyStateBackend backend) { + this.helper = new StateHelper<>(backend, descriptor); + } + + @Override + public Map get() { + Map map = helper.get(); + if (map == null) { + map = new HashMap<>(); + } + return map; + } + + @Override + public V get(K key) { + Map map = get(); + return map.get(key); + } + + @Override + public void put(K key, V value) { + Map map = get(); + + map.put(key, value); + helper.put(map); + } + + @Override + public void update(Map map) { + if (map == null) { + map = new HashMap<>(); + } + helper.put(map); + } + + @Override + public void putAll(Map newMap) { + Map map = get(); + + map.putAll(newMap); + helper.put(map); + } + + @Override + public void remove(K key) { + Map map = get(); + + map.remove(key); + helper.put(map); + } + + /** set current key of the state */ + @Override + public void setCurrentKey(Object currentKey) { + helper.setCurrentKey(currentKey); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/OperatorStateImpl.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/OperatorStateImpl.java new file mode 100644 index 00000000..ee298705 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/OperatorStateImpl.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state.impl; + +import static io.ray.streaming.state.config.ConfigKey.DELIMITER; + +import com.google.common.base.Preconditions; +import io.ray.streaming.state.PartitionRecord; +import io.ray.streaming.state.backend.AbstractKeyStateBackend; +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import io.ray.streaming.state.keystate.state.ListState; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * This class defines the implementation of operator state. When the state is initialized, we must + * scan the whole table. And if the state type is splitList, all the records must be spitted. 
+ */ +public class OperatorStateImpl implements ListState { + + private final ListStateDescriptor descriptor; + private StateHelper>> helper; + private List> allList; + + private AtomicBoolean hasInit; + private boolean isSplit; + + public OperatorStateImpl(ListStateDescriptor descriptor, AbstractKeyStateBackend backend) { + this.descriptor = descriptor; + this.helper = new StateHelper<>(backend, descriptor); + this.isSplit = false; + this.hasInit = new AtomicBoolean(false); + this.allList = new ArrayList<>(); + } + + private void splitList() { + // fetch target list and save + List> list = new ArrayList<>(); + int step = descriptor.getPartitionNumber(); + Preconditions.checkArgument(step > 0); + + for (int round = 0; round * step <= allList.size(); round++) { + int m = round * step + descriptor.getIndex(); + if (m < allList.size()) { + PartitionRecord tmp = allList.get(m); + tmp.setPartitionID(descriptor.getPartitionNumber()); + list.add(tmp); + } + } + helper.put(list, getStateKey()); + allList.clear(); + } + + private void scan() { + int partitionNum = -1; + int index = 0; + while (true) { + List> list = + helper.getBackend().get(descriptor, getKey(descriptor.getIdentify(), index)); + if (list != null && !list.isEmpty()) { + partitionNum = list.get(0).getPartitionID(); + allList.addAll(list); + } + if (++index >= partitionNum) { + break; + } + } + } + + public void init() { + scan(); + + if (isSplit) { + splitList(); + } + } + + private String getKey(String descName, int index) { + String[] stringList = descName.split(DELIMITER); + return String.format("%s%s%s%s%d", stringList[0], DELIMITER, stringList[1], DELIMITER, index); + } + + protected String getStateKey() { + return getKey(this.descriptor.getIdentify(), this.descriptor.getIndex()); + } + + @Override + public void setCurrentKey(Object currentKey) { + throw new UnsupportedOperationException("OperatorState cannot set current key"); + } + + @Override + public List get() { + if (!hasInit.getAndSet(true)) { + 
init(); + } + List> prList = helper.get(getStateKey()); + List list = new ArrayList<>(); + for (PartitionRecord pr : prList) { + list.add(pr.getValue()); + } + return list; + } + + @Override + public void add(V value) { + if (!hasInit.getAndSet(true)) { + init(); + } + List> list = helper.get(getStateKey()); + if (list == null) { + list = new ArrayList<>(); + } + list.add(new PartitionRecord<>(descriptor.getPartitionNumber(), value)); + helper.put(list, getStateKey()); + } + + @Override + public void update(List list) { + List> prList = new ArrayList<>(); + if (list != null) { + for (V value : list) { + prList.add(new PartitionRecord<>(descriptor.getPartitionNumber(), value)); + } + } + helper.put(prList); + } + + public void setSplit(boolean split) { + this.isSplit = split; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/StateHelper.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/StateHelper.java new file mode 100644 index 00000000..9a6fe0f0 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/StateHelper.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state.impl; + +import com.google.common.base.Preconditions; +import io.ray.streaming.state.backend.AbstractKeyStateBackend; +import io.ray.streaming.state.keystate.desc.AbstractStateDescriptor; + +/** State Helper Class. */ +public class StateHelper { + + private final AbstractKeyStateBackend backend; + private final AbstractStateDescriptor descriptor; + + public StateHelper(AbstractKeyStateBackend backend, AbstractStateDescriptor descriptor) { + this.backend = backend; + this.descriptor = descriptor; + } + + protected String getStateKey(String descName) { + Preconditions.checkNotNull(backend, "KeyedBackend must not be null"); + Preconditions.checkNotNull(backend.getCurrentKey(), "currentKey must not be null"); + return this.backend.getBackend().getStateKey(descName, backend.getCurrentKey().toString()); + } + + public void put(T value, String key) { + backend.put(descriptor, key, value); + } + + public void put(T value) { + put(value, getStateKey(descriptor.getIdentify())); + } + + public T get(String key) { + return backend.get(descriptor, key); + } + + public T get() { + return get(getStateKey(descriptor.getIdentify())); + } + + public void setCurrentKey(Object currentKey) { + Preconditions.checkNotNull(backend, "KeyedBackend must not be null"); + this.backend.setCurrentKey(currentKey); + } + + public void setKeyGroupIndex(int keyGroupIndex) { + this.backend.setKeyGroupIndex(keyGroupIndex); + } + + public void resetKeyGroupIndex() { + this.backend.setKeyGroupIndex(-1); + } + + public AbstractStateDescriptor getDescriptor() { + return descriptor; + } + + public AbstractKeyStateBackend getBackend() { + return backend; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/ValueStateImpl.java 
b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/ValueStateImpl.java new file mode 100644 index 00000000..8343b214 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/impl/ValueStateImpl.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state.impl; + +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.keystate.desc.ValueStateDescriptor; +import io.ray.streaming.state.keystate.state.ValueState; + +/** ValueState implementation. 
*/ +public class ValueStateImpl implements ValueState { + + private final StateHelper helper; + + public ValueStateImpl(ValueStateDescriptor descriptor, KeyStateBackend backend) { + this.helper = new StateHelper<>(backend, descriptor); + } + + @Override + public void update(T value) { + helper.put(value); + } + + @Override + public T get() { + T value = helper.get(); + if (null == value) { + return ((ValueStateDescriptor) helper.getDescriptor()).getDefaultValue(); + } else { + return value; + } + } + + /** set current key of the state */ + @Override + public void setCurrentKey(Object currentKey) { + helper.setCurrentKey(currentKey); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/proxy/ListStateStoreManagerProxy.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/proxy/ListStateStoreManagerProxy.java new file mode 100644 index 00000000..ac10f9d1 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/proxy/ListStateStoreManagerProxy.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.keystate.state.proxy; + +import io.ray.streaming.state.KeyValueState; +import io.ray.streaming.state.backend.AbstractKeyStateBackend; +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import io.ray.streaming.state.keystate.state.ListState; +import io.ray.streaming.state.keystate.state.impl.ListStateImpl; +import io.ray.streaming.state.keystate.state.impl.OperatorStateImpl; +import io.ray.streaming.state.strategy.StateStoreManagerProxy; +import java.util.List; + +/** This class defines ListState Wrapper, connecting state and backend. */ +public class ListStateStoreManagerProxy extends StateStoreManagerProxy> + implements KeyValueState> { + + private final ListState listState; + + public ListStateStoreManagerProxy( + AbstractKeyStateBackend keyStateBackend, ListStateDescriptor stateDescriptor) { + super(keyStateBackend, stateDescriptor); + if (stateDescriptor.isOperatorList()) { + this.listState = new OperatorStateImpl<>(stateDescriptor, keyStateBackend); + } else { + this.listState = new ListStateImpl<>(stateDescriptor, keyStateBackend); + } + } + + public ListState getListState() { + return this.listState; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/proxy/MapStateStoreManagerProxy.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/proxy/MapStateStoreManagerProxy.java new file mode 100644 index 00000000..ab872739 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/proxy/MapStateStoreManagerProxy.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state.proxy; + +import io.ray.streaming.state.KeyValueState; +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.keystate.desc.MapStateDescriptor; +import io.ray.streaming.state.keystate.state.MapState; +import io.ray.streaming.state.keystate.state.impl.MapStateImpl; +import io.ray.streaming.state.strategy.StateStoreManagerProxy; +import java.util.Map; + +/** This class defines MapState Wrapper, connecting state and backend. 
*/ +public class MapStateStoreManagerProxy extends StateStoreManagerProxy> + implements KeyValueState> { + + private final MapStateImpl mapState; + + public MapStateStoreManagerProxy( + KeyStateBackend keyStateBackend, MapStateDescriptor stateDescriptor) { + super(keyStateBackend, stateDescriptor); + this.mapState = new MapStateImpl<>(stateDescriptor, keyStateBackend); + } + + public MapState getMapState() { + return this.mapState; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/proxy/ValueStateStoreManagerProxy.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/proxy/ValueStateStoreManagerProxy.java new file mode 100644 index 00000000..ed68bcdc --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/keystate/state/proxy/ValueStateStoreManagerProxy.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.keystate.state.proxy; + +import io.ray.streaming.state.KeyValueState; +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.keystate.desc.ValueStateDescriptor; +import io.ray.streaming.state.keystate.state.ValueState; +import io.ray.streaming.state.keystate.state.impl.ValueStateImpl; +import io.ray.streaming.state.strategy.StateStoreManagerProxy; + +/** This class defines ValueState Wrapper, connecting state and backend. */ +public class ValueStateStoreManagerProxy extends StateStoreManagerProxy + implements KeyValueState { + + private final ValueStateImpl valueState; + + public ValueStateStoreManagerProxy( + KeyStateBackend keyStateBackend, ValueStateDescriptor stateDescriptor) { + super(keyStateBackend, stateDescriptor); + this.valueState = new ValueStateImpl<>(stateDescriptor, keyStateBackend); + } + + public ValueState getValueState() { + return this.valueState; + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/KeyMapStoreSerializer.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/KeyMapStoreSerializer.java new file mode 100644 index 00000000..314289af --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/KeyMapStoreSerializer.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
/**
 * Key-Map serialization and deserialization: a primary key K maps to sub-entries of
 * key type S and value type T.
 */
public interface KeyMapStoreSerializer<K, S, T> {

  byte[] serializeKey(K key);

  byte[] serializeUKey(S uk);

  S deserializeUKey(byte[] ukArray);

  byte[] serializeUValue(T uv);

  T deserializeUValue(byte[] uvArray);
}
/** Key-Value serialization and deserialization. */
public interface KeyValueStoreSerialization<K, V> {

  byte[] serializeKey(K key);

  byte[] serializeValue(V value);

  V deserializeValue(byte[] valueArray);
}
*/ +public class Serializer { + + private static final ThreadLocal conf = + ThreadLocal.withInitial(FSTConfiguration::createDefaultConfiguration); + + public static byte[] object2Bytes(Object value) { + return conf.get().asByteArray(value); + } + + public static Object bytes2Object(byte[] buffer) { + return conf.get().asObject(buffer); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/impl/AbstractSerialization.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/impl/AbstractSerialization.java new file mode 100644 index 00000000..c2d2ec05 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/impl/AbstractSerialization.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.serialization.impl; + +import com.google.common.hash.Hashing; +import io.ray.streaming.state.StateException; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** AbstractSerialization. Generate row key. 
*/ +public abstract class AbstractSerialization { + + private static final Logger LOG = LoggerFactory.getLogger(AbstractSerialization.class); + + public String generateRowKeyPrefix(String key) { + if (StringUtils.isNotEmpty(key)) { + String md5 = Hashing.md5().hashUnencodedChars(key).toString(); + if ("".equals(md5)) { + throw new StateException("Invalid value to md5:" + key); + } + return StringUtils.substring(md5, 0, 4) + ":" + key; + } else { + LOG.warn("key is empty"); + return key; + } + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/impl/DefaultKeyMapStoreSerializer.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/impl/DefaultKeyMapStoreSerializer.java new file mode 100644 index 00000000..fa6dfb6b --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/impl/DefaultKeyMapStoreSerializer.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.serialization.impl; + +import io.ray.streaming.state.serialization.KeyMapStoreSerializer; +import io.ray.streaming.state.serialization.Serializer; + +/** Default Key Map Serialization and Deserialization. */ +public class DefaultKeyMapStoreSerializer extends AbstractSerialization + implements KeyMapStoreSerializer { + + @Override + public byte[] serializeKey(K key) { + String keyWithPrefix = generateRowKeyPrefix(key.toString()); + return keyWithPrefix.getBytes(); + } + + @Override + public byte[] serializeUKey(S uk) { + return Serializer.object2Bytes(uk); + } + + @Override + public S deserializeUKey(byte[] ukArray) { + return (S) Serializer.bytes2Object(ukArray); + } + + @Override + public byte[] serializeUValue(T uv) { + return Serializer.object2Bytes(uv); + } + + @Override + public T deserializeUValue(byte[] uvArray) { + return (T) Serializer.bytes2Object(uvArray); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/impl/DefaultKeyValueStoreSerialization.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/impl/DefaultKeyValueStoreSerialization.java new file mode 100644 index 00000000..9a9c7f63 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/serialization/impl/DefaultKeyValueStoreSerialization.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.serialization.impl; + +import io.ray.streaming.state.serialization.KeyValueStoreSerialization; +import io.ray.streaming.state.serialization.Serializer; + +/** KV Store Serialization and Deserialization. */ +public class DefaultKeyValueStoreSerialization extends AbstractSerialization + implements KeyValueStoreSerialization { + + @Override + public byte[] serializeKey(K key) { + String keyWithPrefix = generateRowKeyPrefix(key.toString()); + return keyWithPrefix.getBytes(); + } + + @Override + public byte[] serializeValue(V value) { + return Serializer.object2Bytes(value); + } + + @Override + public V deserializeValue(byte[] valueArray) { + return (V) Serializer.bytes2Object(valueArray); + } +} diff --git a/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/store/KeyMapStore.java b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/store/KeyMapStore.java new file mode 100644 index 00000000..21981f81 --- /dev/null +++ b/streaming/java/streaming-state/src/main/java/io/ray/streaming/state/store/KeyMapStore.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.ray.streaming.state.store;

import java.io.IOException;
import java.util.Map;

/**
 * Key Map Store interface: a {@link KeyValueStore} whose value is itself a map,
 * with incremental access to individual sub-entries.
 *
 * @param <K> primary key type
 * @param <S> sub-key type
 * @param <T> sub-value type
 */
public interface KeyMapStore<K, S, T> extends KeyValueStore<K, Map<S, T>> {

  /** Put a single sub-key/value pair into the map stored under {@code key}, incrementally. */
  void put(K key, S subKey, T value) throws IOException;

  /** Get the sub-value stored under {@code key}/{@code subKey} from the store. */
  T get(K key, S subKey) throws IOException;
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.ray.streaming.state.store;

import java.io.IOException;
import java.io.Serializable;

/**
 * Key Value Store interface.
 *
 * @param <K> key type
 * @param <V> value type
 */
public interface KeyValueStore<K, V> extends Serializable {

  /** Put key value into store. */
  void put(K key, V value) throws IOException;

  /** Get value from store. */
  V get(K key) throws IOException;

  /** Remove key in the store. */
  void remove(K key) throws IOException;

  /** Flush buffered writes to persistent storage. */
  void flush() throws IOException;

  /** Clear all cached (in-memory) data without touching persistent storage. */
  void clearCache();

  /** Close the store and release its resources. */
  void close() throws IOException;
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.ray.streaming.state.store.impl;

import io.ray.streaming.state.store.KeyMapStore;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Memory Key Map Store: an in-process, non-persistent {@link KeyMapStore}.
 *
 * <p>The outer map is a ConcurrentHashMap; inner maps are plain HashMaps
 * (null sub-values allowed, no inner-map concurrency guarantee).
 */
public class MemoryKeyMapStore<K, S, T> implements KeyMapStore<K, S, T> {

  private final Map<K, Map<S, T>> memoryStore;

  public MemoryKeyMapStore() {
    // ConcurrentHashMap directly (Guava's Maps.newConcurrentMap() returns the same type).
    this.memoryStore = new ConcurrentHashMap<>();
  }

  @Override
  public void put(K key, Map<S, T> value) throws IOException {
    this.memoryStore.put(key, value);
  }

  @Override
  public void put(K key, S subKey, T value) throws IOException {
    // computeIfAbsent is atomic on ConcurrentHashMap; the previous
    // containsKey-then-put sequence could lose an entry under concurrent puts.
    memoryStore.computeIfAbsent(key, k -> new HashMap<>()).put(subKey, value);
  }

  @Override
  public Map<S, T> get(K key) throws IOException {
    return this.memoryStore.get(key);
  }

  @Override
  public T get(K key, S subKey) throws IOException {
    // Single lookup avoids the containsKey/get race and double hashing.
    Map<S, T> subMap = memoryStore.get(key);
    return subMap == null ? null : subMap.get(subKey);
  }

  @Override
  public void remove(K key) throws IOException {
    this.memoryStore.remove(key);
  }

  /** No-op: memory store has nothing to persist. */
  @Override
  public void flush() throws IOException {}

  /** No-op: the backing map itself is the cache; clearing it would lose data. */
  @Override
  public void clearCache() {}

  @Override
  public void close() throws IOException {
    if (memoryStore != null) {
      memoryStore.clear();
    }
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.ray.streaming.state.store.impl;

import io.ray.streaming.state.store.KeyValueStore;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Memory Key Value Store: an in-process, non-persistent {@link KeyValueStore}
 * backed by a ConcurrentHashMap.
 */
public class MemoryKeyValueStore<K, V> implements KeyValueStore<K, V> {

  private final Map<K, V> memoryStore;

  public MemoryKeyValueStore() {
    // ConcurrentHashMap directly (Guava's Maps.newConcurrentMap() returns the same type).
    this.memoryStore = new ConcurrentHashMap<>();
  }

  @Override
  public void put(K key, V value) throws IOException {
    this.memoryStore.put(key, value);
  }

  @Override
  public V get(K key) throws IOException {
    return this.memoryStore.get(key);
  }

  @Override
  public void remove(K key) throws IOException {
    this.memoryStore.remove(key);
  }

  /** No-op: memory store has nothing to persist. */
  @Override
  public void flush() throws IOException {}

  /** No-op: the backing map itself is the cache; clearing it would lose data. */
  @Override
  public void clearCache() {}

  @Override
  public void close() throws IOException {
    if (memoryStore != null) {
      memoryStore.clear();
    }
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.ray.streaming.state.strategy;

import io.ray.streaming.state.StateException;
import io.ray.streaming.state.StateStoreManager;
import io.ray.streaming.state.StorageRecord;
import io.ray.streaming.state.serialization.Serializer;
import io.ray.streaming.state.store.KeyValueStore;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * This class defines the StoreManager abstract class. We use three layers to store the state:
 * frontStore (read-write, current checkpoint), middleStore (read-only, finished but
 * uncommitted checkpoints) and keyValueStore (remote storage).
 *
 * @param <V> state value type
 */
public abstract class AbstractStateStoreManager<V> implements StateStoreManager {

  /** Read-write layer: records written during the in-flight checkpoint, by state key. */
  protected Map<String, StorageRecord<V>> frontStore = new ConcurrentHashMap<>();

  /** Remote-storage layer: state key -> (version tag -> serialized record). */
  protected KeyValueStore<String, Map<Long, byte[]>> kvStore;

  /** Read-only layer: finished-but-uncommitted checkpoints, by checkpoint id. */
  protected Map<Long, Map<String, byte[]>> middleStore = new ConcurrentHashMap<>();

  protected int keyGroupIndex = -1;

  public AbstractStateStoreManager(KeyValueStore<String, Map<Long, byte[]>> backStore) {
    kvStore = backStore;
  }

  public byte[] toBytes(StorageRecord<V> storageRecord) {
    return Serializer.object2Bytes(storageRecord);
  }

  @SuppressWarnings("unchecked")
  public StorageRecord<V> toStorageRecord(byte[] data) {
    // Unchecked by design: data is assumed to come from toBytes.
    return (StorageRecord<V>) Serializer.bytes2Object(data);
  }

  /** Read the value for {@code key} visible at {@code checkpointId}. */
  public abstract V get(long checkpointId, String key);

  /** Write {@code v} into the front store, tagged with {@code checkpointId}. */
  public void put(long checkpointId, String k, V v) {
    frontStore.put(k, new StorageRecord<>(checkpointId, v));
  }

  @Override
  public void ackCommit(long checkpointId, long timeStamp) {
    // The timestamp is unused here; delegate to the single-argument variant.
    ackCommit(checkpointId);
  }

  public abstract void ackCommit(long checkpointId);

  public void setKeyGroupIndex(int keyGroupIndex) {
    this.keyGroupIndex = keyGroupIndex;
  }

  /** Drop all in-memory layers and close the remote store. */
  public void close() {
    frontStore.clear();
    middleStore.clear();
    if (kvStore != null) {
      kvStore.clearCache();
      try {
        kvStore.close();
      } catch (IOException e) {
        throw new StateException(e);
      }
    }
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.ray.streaming.state.strategy;

import io.ray.streaming.state.StateException;
import io.ray.streaming.state.StorageRecord;
import io.ray.streaming.state.store.KeyValueStore;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * This class defines the checkpoint store strategy, which saves two-version data at once.
 *
 * @param <V> state value type
 */
public class DualStateStoreManager<V> extends AbstractStateStoreManager<V> {

  private static final Logger LOG = LoggerFactory.getLogger(DualStateStoreManager.class);

  public DualStateStoreManager(KeyValueStore<String, Map<Long, byte[]>> backStore) {
    super(backStore);
  }

  /** Big-endian long -> 8 bytes (stdlib replacement for Guava Longs.toByteArray). */
  private static byte[] longToBytes(long value) {
    return ByteBuffer.allocate(Long.BYTES).putLong(value).array();
  }

  /** Big-endian 8 bytes -> long (stdlib replacement for Guava Longs.fromByteArray). */
  private static long bytesToLong(byte[] bytes) {
    return ByteBuffer.wrap(bytes).getLong();
  }

  @Override
  public void finish(long checkpointId) {
    LOG.info("do finish checkpointId:{}", checkpointId);
    // Freeze the front store into the middle store under this checkpoint id.
    Map<String, byte[]> cpStore = new HashMap<>();
    for (Entry<String, StorageRecord<V>> entry : frontStore.entrySet()) {
      String key = entry.getKey();
      StorageRecord<V> value = entry.getValue();
      cpStore.put(key, toBytes(value));
    }
    middleStore.put(checkpointId, cpStore);
    frontStore.clear();
  }

  @Override
  public void commit(long checkpointId) {
    try {
      LOG.info("do commit checkpointId:{}", checkpointId);
      Map<String, byte[]> cpStore = middleStore.get(checkpointId);
      if (cpStore == null) {
        throw new StateException("why cp store is null");
      }
      for (Entry<String, byte[]> entry : cpStore.entrySet()) {
        String key = entry.getKey();
        byte[] value = entry.getValue();

        /*
         * 2 is a specific key in the kv store and indicates that the new value should be
         * stored with this key after moving the old value to key 1. The negative keys -1/-2
         * hold the checkpoint id of the value stored under 1/2. For example:
         *
         * <pre>
         *        -2  -1   1   2
         *   k1    6   5   a   b
         *   k2    9   7   d   e
         * </pre>
         *
         * k1's value for checkpoint 5 is a, and b for checkpoint 6.
         */
        Map<Long, byte[]> remoteData = super.kvStore.get(key);
        if (remoteData == null || remoteData.isEmpty()) {
          remoteData = new HashMap<>();
          remoteData.put(2L, value);
          remoteData.put(-2L, longToBytes(checkpointId));
        } else {
          long oldBatchId = bytesToLong(remoteData.get(-2L));
          if (oldBatchId < checkpointId) {
            // move the old data
            remoteData.put(1L, remoteData.get(2L));
            remoteData.put(-1L, remoteData.get(-2L));
          }

          // put the new data here
          remoteData.put(2L, value);
          remoteData.put(-2L, longToBytes(checkpointId));
        }
        super.kvStore.put(key, remoteData);
      }
      super.kvStore.flush();
    } catch (Exception e) {
      LOG.error(e.getMessage(), e);
      throw new StateException(e);
    }
  }

  @Override
  public void rollBack(long checkpointId) {
    LOG.info("do rollBack checkpointId:{}", checkpointId);
    // Drop every uncommitted layer; committed remote data is the rollback target.
    this.frontStore.clear();
    this.middleStore.clear();
    this.kvStore.clearCache();
  }

  @Override
  public V get(long checkpointId, String key) {
    // 1) current checkpoint cache
    StorageRecord<V> storageRecord = frontStore.get(key);
    if (storageRecord != null) {
      return storageRecord.getValue();
    }

    // 2) finished-but-uncommitted checkpoints, newest first
    List<Long> checkpointIds = new ArrayList<>(middleStore.keySet());
    Collections.sort(checkpointIds);
    for (int i = checkpointIds.size() - 1; i >= 0; i--) {
      Map<String, byte[]> cpStore = middleStore.get(checkpointIds.get(i));
      if (cpStore != null && cpStore.containsKey(key)) {
        byte[] cpData = cpStore.get(key);
        storageRecord = toStorageRecord(cpData);
        return storageRecord.getValue();
      }
    }

    // 3) remote storage: pick the newest version strictly older than checkpointId
    try {
      Map<Long, byte[]> remoteData = super.kvStore.get(key);
      if (remoteData != null) {
        for (Entry<Long, byte[]> entry : remoteData.entrySet()) {
          // Positive keys hold serialized records; negative keys hold checkpoint-id tags.
          if (entry.getKey() > 0) {
            StorageRecord<V> tmp = toStorageRecord(entry.getValue());
            if (tmp.getCheckpointId() < checkpointId) {
              if (storageRecord == null
                  || storageRecord.getCheckpointId() < tmp.getCheckpointId()) {
                storageRecord = tmp;
              }
            }
          }
        }
        if (storageRecord != null) {
          return storageRecord.getValue();
        }
      }
    } catch (Exception e) {
      LOG.error("get checkpointId:" + checkpointId + " key:" + key, e);
      throw new StateException(e);
    }
    return null;
  }

  @Override
  public void ackCommit(long checkpointId) {
    LOG.info("do ackCommit checkpointId:{}", checkpointId);
    middleStore.remove(checkpointId);
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.ray.streaming.state.strategy;

import io.ray.streaming.state.StateException;
import io.ray.streaming.state.StorageRecord;
import io.ray.streaming.state.store.KeyValueStore;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
 * This class defines the multi-version store strategy, which leverages external storage's mvcc:
 * every checkpoint is written to the remote store under its own checkpoint-id key.
 *
 * @param <V> state value type
 */
public class MVStateStoreManager<V> extends AbstractStateStoreManager<V> {

  public MVStateStoreManager(KeyValueStore<String, Map<Long, byte[]>> kvStore) {
    super(kvStore);
  }

  @Override
  public void finish(long checkpointId) {
    // Freeze the front store into the middle store under this checkpoint id.
    Map<String, byte[]> currentStateRecords = new HashMap<>();
    for (Entry<String, StorageRecord<V>> entry : frontStore.entrySet()) {
      currentStateRecords.put(entry.getKey(), toBytes(entry.getValue()));
    }

    middleStore.put(checkpointId, currentStateRecords);
    frontStore.clear();
  }

  @Override
  public void commit(long checkpointId) {
    // Write every finished checkpoint <= checkpointId to external storage, newest first.
    List<Long> checkpointIds = new ArrayList<>(middleStore.keySet());
    Collections.sort(checkpointIds);

    for (int i = checkpointIds.size() - 1; i >= 0; i--) {
      long commitBatchId = checkpointIds.get(i);
      if (commitBatchId > checkpointId) {
        continue;
      }

      Map<String, byte[]> commitRecords = middleStore.get(commitBatchId);

      try {
        for (Entry<String, byte[]> entry : commitRecords.entrySet()) {
          Map<Long, byte[]> remoteData = this.kvStore.get(entry.getKey());
          if (remoteData == null) {
            remoteData = new HashMap<>();
          }

          // One version slot per checkpoint id (mvcc).
          remoteData.put(commitBatchId, entry.getValue());

          this.kvStore.put(entry.getKey(), remoteData);
        }
        this.kvStore.flush();
      } catch (Exception e) {
        throw new StateException(e);
      }
    }
  }

  @Override
  public void rollBack(long checkpointId) {
    // Drop every uncommitted layer; committed remote data is the rollback target.
    this.frontStore.clear();
    this.middleStore.clear();
    this.kvStore.clearCache();
  }

  @Override
  public V get(long checkpointId, String key) {
    // 1) current checkpoint cache
    StorageRecord<V> valueArray = frontStore.get(key);
    if (valueArray != null) {
      return valueArray.getValue();
    }

    // 2) finished-but-uncommitted checkpoints <= checkpointId, newest first
    List<Long> checkpointIds = new ArrayList<>(middleStore.keySet());
    Collections.sort(checkpointIds);
    for (int i = checkpointIds.size() - 1; i >= 0; i--) {
      if (checkpointIds.get(i) > checkpointId) {
        continue;
      }

      Map<String, byte[]> records = middleStore.get(checkpointIds.get(i));
      if (records != null && records.containsKey(key)) {
        byte[] bytes = records.get(key);
        return toStorageRecord(bytes).getValue();
      }
    }

    // 3) external storage: newest version <= checkpointId
    try {
      Map<Long, byte[]> remoteData = this.kvStore.get(key);
      if (remoteData != null) {
        checkpointIds = new ArrayList<>(remoteData.keySet());
        Collections.sort(checkpointIds);

        for (int i = checkpointIds.size() - 1; i >= 0; i--) {
          if (checkpointIds.get(i) > checkpointId) {
            continue;
          }

          byte[] bytes = remoteData.get(checkpointIds.get(i));
          return toStorageRecord(bytes).getValue();
        }
      }
    } catch (Exception e) {
      throw new StateException(e);
    }
    return null;
  }

  // NOTE: the redundant put(long, String, V) override (identical to the superclass
  // implementation) was removed; behavior is unchanged.

  @Override
  public void ackCommit(long checkpointId) {
    // Discard every finished checkpoint <= checkpointId; they are now durable remotely.
    List<Long> checkpointIds = new ArrayList<>(middleStore.keySet());
    Collections.sort(checkpointIds);

    for (int i = checkpointIds.size() - 1; i >= 0; i--) {
      long commitBatchId = checkpointIds.get(i);
      if (commitBatchId > checkpointId) {
        continue;
      }

      this.middleStore.remove(commitBatchId);
    }
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.ray.streaming.state.strategy;

import io.ray.streaming.state.StateStoreManager;
import io.ray.streaming.state.backend.AbstractKeyStateBackend;
import io.ray.streaming.state.backend.StateStrategy;
import io.ray.streaming.state.keystate.desc.AbstractStateDescriptor;
import io.ray.streaming.state.store.KeyValueStore;
import java.util.Map;

/**
 * This class supports ITransactionState.
 *
 * <p>Based on the given StorageMode, a different implementation instance of the
 * AbstractStateStoreManager class will be created. All method calls will be delegated to the
 * strategy instance.
 *
 * @param <V> state value type
 */
public abstract class StateStoreManagerProxy<V> implements StateStoreManager {

  protected final AbstractStateStoreManager<V> stateStrategy;
  private final AbstractKeyStateBackend keyStateBackend;

  public StateStoreManagerProxy(
      AbstractKeyStateBackend keyStateBackend, AbstractStateDescriptor stateDescriptor) {
    this.keyStateBackend = keyStateBackend;
    KeyValueStore<String, Map<Long, byte[]>> backStorage =
        keyStateBackend.getBackStorage(stateDescriptor);
    // Local renamed from "stateStrategy" to avoid shadowing the field of the same name.
    StateStrategy strategy = keyStateBackend.getStateStrategy();
    switch (strategy) {
      case DUAL_VERSION:
        this.stateStrategy = new DualStateStoreManager<>(backStorage);
        break;
      case SINGLE_VERSION:
        // NOTE(review): SINGLE_VERSION maps to the multi-version (mvcc) manager —
        // presumably intentional naming; confirm against the StateStrategy enum docs.
        this.stateStrategy = new MVStateStoreManager<>(backStorage);
        break;
      default:
        throw new UnsupportedOperationException("Unsupported state strategy: " + strategy);
    }
  }

  protected void setKeyGroupIndex(int index) {
    this.stateStrategy.setKeyGroupIndex(index);
  }

  @Override
  public void finish(long checkpointId) {
    this.stateStrategy.finish(checkpointId);
  }

  /** The commit can be used in another thread to reach async state commit. */
  @Override
  public void commit(long checkpointId) {
    this.stateStrategy.commit(checkpointId);
  }

  /** The ackCommit must be called after commit in the same thread. */
  @Override
  public void ackCommit(long checkpointId, long timeStamp) {
    this.stateStrategy.ackCommit(checkpointId);
  }

  @Override
  public void rollBack(long checkpointId) {
    this.stateStrategy.rollBack(checkpointId);
  }

  public void close() {
    this.stateStrategy.close();
  }

  public V get(String key) {
    this.stateStrategy.setKeyGroupIndex(keyStateBackend.getKeyGroupIndex());
    return this.stateStrategy.get(this.keyStateBackend.getCheckpointId(), key);
  }

  public void put(String key, V value) {
    this.stateStrategy.setKeyGroupIndex(keyStateBackend.getKeyGroupIndex());
    this.stateStrategy.put(this.keyStateBackend.getCheckpointId(), key, value);
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.ray.streaming.state.backend;

import io.ray.streaming.state.keystate.KeyGroup;
import io.ray.streaming.state.keystate.desc.ListStateDescriptor;
import io.ray.streaming.state.keystate.desc.MapStateDescriptor;
import io.ray.streaming.state.keystate.desc.ValueStateDescriptor;
import io.ray.streaming.state.keystate.state.ListState;
import io.ray.streaming.state.keystate.state.MapState;
import io.ray.streaming.state.keystate.state.ValueState;
import java.util.Arrays;
import java.util.HashMap;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Exercises value/list/map state through the finish/commit/rollBack/ackCommit checkpoint
 * lifecycle against the memory backend. The three scenario methods are not annotated with
 * {@code @Test}; they share one backend instance and are driven in order by {@link #testMem()}.
 */
public class KeyStateBackendTest {

  private AbstractStateBackend stateBackend;
  private KeyStateBackend keyStateBackend;

  public void testGetValueState() {
    keyStateBackend.setCheckpointId(1L);
    ValueStateDescriptor<String> valueStateDescriptor =
        ValueStateDescriptor.build("value", String.class, null);
    valueStateDescriptor.setTableName("kepler_hlg_ut");
    ValueState<String> valueState = keyStateBackend.getValueState(valueStateDescriptor);

    valueState.setCurrentKey("1");
    valueState.update("hello");
    Assert.assertEquals(valueState.get(), "hello");

    valueState.update("hello1");
    Assert.assertEquals(valueState.get(), "hello1");

    valueState.setCurrentKey("2");
    Assert.assertEquals(valueState.get(), null);

    valueState.update("eagle");
    Assert.assertEquals(valueState.get(), "eagle");

    // Rollback before any finish: all uncommitted writes are discarded.
    keyStateBackend.rollBack(1);
    valueState.setCurrentKey("1");
    Assert.assertEquals(valueState.get(), null);
    valueState.setCurrentKey("2");
    Assert.assertEquals(valueState.get(), null);

    valueState.setCurrentKey("1");
    valueState.update("eagle");
    keyStateBackend.finish(1);

    keyStateBackend.setCheckpointId(2);
    valueState.setCurrentKey("2");
    valueState.update("tim");

    valueState.setCurrentKey("2-1");
    valueState.update("jim");
    keyStateBackend.finish(2);

    keyStateBackend.setCheckpointId(3);
    valueState.setCurrentKey("3");
    valueState.update("lucy");
    keyStateBackend.finish(3);

    keyStateBackend.setCheckpointId(4);
    valueState.setCurrentKey("4");
    valueState.update("eric");
    keyStateBackend.finish(4);

    keyStateBackend.setCheckpointId(5);
    valueState.setCurrentKey("4");
    valueState.update("eric-1");
    valueState.setCurrentKey("5");
    valueState.update("jack");
    keyStateBackend.finish(5);
    keyStateBackend.commit(5);

    keyStateBackend.setCheckpointId(6);
    valueState.setCurrentKey("5");
    Assert.assertEquals(valueState.get(), "jack");

    valueState.setCurrentKey("4");
    Assert.assertEquals(valueState.get(), "eric-1");

    valueState.setCurrentKey(4);
    valueState.update("if-ttt");
    Assert.assertEquals(valueState.get(), "if-ttt");

    keyStateBackend.setCheckpointId(7);
    valueState.setCurrentKey(9);
    valueState.update("6666");

    // Rollback to the committed checkpoint 5: post-5 writes disappear.
    keyStateBackend.rollBack(5);
    keyStateBackend.setCheckpointId(6);
    valueState.setCurrentKey("4");
    Assert.assertEquals(valueState.get(), "eric-1");
    valueState.setCurrentKey("5");
    Assert.assertEquals(valueState.get(), "jack");
    valueState.setCurrentKey("9");
    Assert.assertNull(valueState.get());
  }

  public void testGetListState() {
    keyStateBackend.setCheckpointId(1L);
    ListStateDescriptor<String> listStateDescriptor =
        ListStateDescriptor.build("list", String.class);
    listStateDescriptor.setTableName("kepler_hlg_ut");
    ListState<String> listState = keyStateBackend.getListState(listStateDescriptor);

    listState.setCurrentKey("1");
    listState.add("hello1");
    Assert.assertEquals(listState.get(), Arrays.asList("hello1"));

    listState.add("hello2");
    Assert.assertEquals(listState.get(), Arrays.asList("hello1", "hello2"));

    listState.setCurrentKey("2");
    Assert.assertEquals(listState.get(), Arrays.asList());

    listState.setCurrentKey("2");
    listState.add("eagle");
    listState.setCurrentKey("1");
    Assert.assertEquals(listState.get(), Arrays.asList("hello1", "hello2"));
    listState.setCurrentKey("2");
    Assert.assertEquals(listState.get(), Arrays.asList("eagle"));

    // Rollback before any finish: all uncommitted writes are discarded.
    keyStateBackend.rollBack(1);
    listState.setCurrentKey("1");
    Assert.assertEquals(listState.get(), Arrays.asList());
    listState.setCurrentKey("2");
    Assert.assertEquals(listState.get(), Arrays.asList());

    listState.setCurrentKey("1");
    listState.add("eagle");
    listState.add("eagle-2");
    keyStateBackend.finish(1);

    keyStateBackend.setCheckpointId(2);
    listState.setCurrentKey("2");
    listState.add("tim");

    listState.setCurrentKey("2-1");
    listState.add("jim");
    keyStateBackend.finish(2);

    keyStateBackend.setCheckpointId(3);
    listState.setCurrentKey("3");
    listState.add("lucy");
    keyStateBackend.finish(3);

    keyStateBackend.setCheckpointId(4);
    listState.setCurrentKey("4");
    listState.add("eric");
    keyStateBackend.finish(4);

    keyStateBackend.setCheckpointId(5);
    listState.setCurrentKey("4");
    listState.add("eric-1");
    Assert.assertEquals(listState.get(), Arrays.asList("eric", "eric-1"));

    listState.setCurrentKey("5");
    listState.add("jack");
    keyStateBackend.finish(5);
    keyStateBackend.commit(5);

    keyStateBackend.setCheckpointId(6);
    listState.setCurrentKey("5");
    Assert.assertEquals(listState.get(), Arrays.asList("jack"));

    listState.setCurrentKey("4");
    Assert.assertEquals(listState.get(), Arrays.asList("eric", "eric-1"));

    listState.setCurrentKey(4);
    listState.add("if-ttt");
    Assert.assertEquals(listState.get(), Arrays.asList("eric", "eric-1", "if-ttt"));

    keyStateBackend.setCheckpointId(7);
    listState.setCurrentKey(9);
    listState.add("6666");

    // Rollback to the committed checkpoint 5: post-5 writes disappear.
    keyStateBackend.rollBack(5);
    keyStateBackend.setCheckpointId(6);
    listState.setCurrentKey("4");
    Assert.assertEquals(listState.get(), Arrays.asList("eric", "eric-1"));
    listState.setCurrentKey("5");
    Assert.assertEquals(listState.get(), Arrays.asList("jack"));
    listState.setCurrentKey("9");
    Assert.assertEquals(listState.get(), Arrays.asList());
  }

  public void testGetMapState() {
    keyStateBackend.setCheckpointId(1L);
    MapStateDescriptor<String, String> mapStateDescriptor =
        MapStateDescriptor.build("map", String.class, String.class);
    mapStateDescriptor.setTableName("kepler_hlg_ut");
    MapState<String, String> mapState = keyStateBackend.getMapState(mapStateDescriptor);

    mapState.setCurrentKey("1");
    mapState.put("hello1", "world1");
    Assert.assertEquals(mapState.get("hello1"), "world1");

    mapState.put("hello2", "world2");
    Assert.assertEquals(mapState.get("hello2"), "world2");
    Assert.assertEquals(mapState.get("hello1"), "world1");
    Assert.assertEquals(mapState.get("hello3"), null);

    mapState.setCurrentKey("2");
    // Assert.assertEquals(mapState.iterator(), (new HashMap()));

    mapState.setCurrentKey("2");
    mapState.put("eagle", "eagle-1");
    mapState.setCurrentKey("1");
    Assert.assertEquals(mapState.get("hello1"), "world1");
    mapState.setCurrentKey("2");
    Assert.assertEquals(mapState.get("eagle"), "eagle-1");
    Assert.assertEquals(mapState.get("xxx"), null);

    // Rollback before any finish: all uncommitted writes are discarded.
    keyStateBackend.rollBack(1);
    mapState.setCurrentKey("1");
    Assert.assertEquals(mapState.iterator(), (new HashMap()).entrySet().iterator());
    mapState.setCurrentKey("2");
    Assert.assertEquals(mapState.iterator(), (new HashMap()).entrySet().iterator());

    mapState.setCurrentKey("1");
    mapState.put("eagle", "eagle-1");
    mapState.put("eagle-2", "eagle-3");
    keyStateBackend.finish(1);

    keyStateBackend.setCheckpointId(2);
    mapState.setCurrentKey("2");
    mapState.put("tim", "tina");

    mapState.setCurrentKey("2-1");
    mapState.put("jim", "tick");
    keyStateBackend.finish(2);

    keyStateBackend.setCheckpointId(3);
    mapState.setCurrentKey("3");
    mapState.put("lucy", "ja");
    keyStateBackend.finish(3);

    keyStateBackend.setCheckpointId(4);
    mapState.setCurrentKey("4");
    mapState.put("eric", "sam");
    keyStateBackend.finish(4);

    keyStateBackend.setCheckpointId(5);
    mapState.setCurrentKey("4");
    mapState.put("eric-1", "zxy");
    Assert.assertEquals(mapState.get("eric-1"), "zxy");
    Assert.assertEquals(mapState.get("eric"), "sam");

    mapState.setCurrentKey("5");
    mapState.put("jack", "zhang");
    keyStateBackend.finish(5);
    keyStateBackend.commit(5);

    keyStateBackend.setCheckpointId(6);
    mapState.setCurrentKey("5");
    Assert.assertEquals(mapState.get("jack"), "zhang");
    mapState.put("hlll", "gggg");

    mapState.setCurrentKey("4");
    Assert.assertEquals(mapState.get("eric-1"), "zxy");
    Assert.assertEquals(mapState.get("eric"), "sam");

    mapState.setCurrentKey(4);
    mapState.put("if-ttt", "if-ggg");
    Assert.assertEquals(mapState.get("if-ttt"), "if-ggg");

    keyStateBackend.setCheckpointId(7);
    mapState.setCurrentKey(9);
    mapState.put("6666", "7777");

    // Rollback to the committed checkpoint 5: post-5 writes disappear.
    keyStateBackend.rollBack(5);
    keyStateBackend.setCheckpointId(6);
    mapState.setCurrentKey("4");
    Assert.assertEquals(mapState.get("eric-1"), "zxy");
    Assert.assertEquals(mapState.get("eric"), "sam");
    Assert.assertNull(mapState.get("if-ttt"));

    mapState.setCurrentKey("5");
    Assert.assertNull(mapState.get("hlll"));
    mapState.setCurrentKey("9");
    Assert.assertNull(mapState.get("6666"));
  }

  @Test
  public void testMem() {
    stateBackend = StateBackendBuilder.buildStateBackend(new HashMap<>());
    keyStateBackend = new KeyStateBackend(4, new KeyGroup(2, 3), stateBackend);
    testGetValueState();
    testGetListState();
    testGetMapState();
  }
}
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.impl; + +import io.ray.streaming.state.serialization.impl.DefaultKeyMapStoreSerializer; +import java.util.HashMap; +import java.util.Map; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class DefaultKeyMapStoreSerializationTest { + + private DefaultKeyMapStoreSerializer<String, String, Map<String, String>> defaultKMapStoreSerDe; + + @BeforeClass + public void setUp() { + this.defaultKMapStoreSerDe = new DefaultKeyMapStoreSerializer<>(); + } + + @Test + public void testSerKey() { + String key = "hello"; + byte[] result = this.defaultKMapStoreSerDe.serializeKey(key); + String keyWithPrefix = this.defaultKMapStoreSerDe.generateRowKeyPrefix(key); + Assert.assertEquals(result, keyWithPrefix.getBytes()); + } + + @Test + public void testSerUKey() { + String subKey = "hell1"; + byte[] result = this.defaultKMapStoreSerDe.serializeUKey(subKey); + Assert.assertEquals(subKey, this.defaultKMapStoreSerDe.deserializeUKey(result)); + } + + @Test + public void testSerUValue() { + Map value = new HashMap<>(); + value.put("foo", "bar"); + byte[] result = this.defaultKMapStoreSerDe.serializeUValue(value); + Assert.assertEquals(value, this.defaultKMapStoreSerDe.deserializeUValue(result)); + } +} diff --git
a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/impl/DefaultKeyValueStoreSerializationTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/impl/DefaultKeyValueStoreSerializationTest.java new file mode 100644 index 00000000..332aa829 --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/impl/DefaultKeyValueStoreSerializationTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.impl; + +import io.ray.streaming.state.serialization.impl.DefaultKeyValueStoreSerialization; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class DefaultKeyValueStoreSerializationTest { + + DefaultKeyValueStoreSerialization serDe = + new DefaultKeyValueStoreSerialization<>(); + byte[] ret; + + @Test + public void testSerKey() throws Exception { + ret = serDe.serializeKey("key"); + String key = new String(ret); + Assert.assertEquals(key.indexOf("key"), 5); + } + + @Test + public void testSerValue() throws Exception { + ret = serDe.serializeValue(5); + Assert.assertEquals(ret.length, 2); + Assert.assertEquals((int) serDe.deserializeValue(ret), 5); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/impl/MemoryKeyMapStoreTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/impl/MemoryKeyMapStoreTest.java new file mode 100644 index 00000000..f16103e5 --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/impl/MemoryKeyMapStoreTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.impl; + +import com.google.common.collect.Maps; +import io.ray.streaming.state.backend.AbstractStateBackend; +import io.ray.streaming.state.backend.StateBackendBuilder; +import io.ray.streaming.state.store.KeyMapStore; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class MemoryKeyMapStoreTest { + + private AbstractStateBackend stateBackend; + private KeyMapStore IKeyMapStore; + + @BeforeClass + public void setUp() { + stateBackend = StateBackendBuilder.buildStateBackend(new HashMap()); + IKeyMapStore = stateBackend.getKeyMapStore("test-table"); + } + + @Test + public void testCase() { + try { + Assert.assertNull(IKeyMapStore.get("hello")); + Map map = Maps.newHashMap(); + map.put("1", "1-1"); + map.put("2", "2-1"); + + IKeyMapStore.put("hello", map); + Assert.assertEquals(IKeyMapStore.get("hello"), map); + + Map map2 = Maps.newHashMap(); + map2.put("3", "3-1"); + map2.put("4", "4-1"); + IKeyMapStore.put("hello", map2); + Assert.assertNotEquals(IKeyMapStore.get("hello"), map); + Assert.assertEquals(IKeyMapStore.get("hello"), map2); + + } catch (IOException e) { + Assert.fail("unexpected IOException: " + e.getMessage()); + } + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/impl/MemoryKeyValueStoreTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/impl/MemoryKeyValueStoreTest.java new file mode 100644 index 00000000..84487d53 --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/impl/MemoryKeyValueStoreTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.impl; + +import io.ray.streaming.state.backend.AbstractStateBackend; +import io.ray.streaming.state.backend.StateBackendBuilder; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class MemoryKeyValueStoreTest { + + private AbstractStateBackend stateBackend; + private io.ray.streaming.state.store.KeyValueStore KeyValueStore; + + @BeforeClass + public void setUp() { + Map config = new HashMap<>(); + stateBackend = StateBackendBuilder.buildStateBackend(config); + KeyValueStore = stateBackend.getKeyValueStore("kepler_hlg_ut"); + } + + @Test + public void testCase() { + try { + KeyValueStore.put("hello", "world"); + Assert.assertEquals(KeyValueStore.get("hello"), "world"); + KeyValueStore.put("hello", "world1"); + Assert.assertEquals(KeyValueStore.get("hello"), "world1"); + Assert.assertNull(KeyValueStore.get("hello1")); + } catch (IOException e) { + Assert.fail("unexpected IOException: " + e.getMessage()); + } + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/KeyGroupAssignmentTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/KeyGroupAssignmentTest.java new file mode 100644 index 00000000..22d5d36d --- /dev/null +++
b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/KeyGroupAssignmentTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate; + +import static org.testng.Assert.assertEquals; + +import org.testng.annotations.Test; + +public class KeyGroupAssignmentTest { + + @Test + public void testComputeKeyGroupRangeForOperatorIndex() throws Exception { + KeyGroup keyGroup = KeyGroupAssignment.getKeyGroup(4096, 1, 0); + assertEquals(keyGroup.getStartIndex(), 0); + assertEquals(keyGroup.getEndIndex(), 4095); + assertEquals(keyGroup.size(), 4096); + + KeyGroup keyGroup2 = KeyGroupAssignment.getKeyGroup(4096, 2, 0); + assertEquals(keyGroup2.getStartIndex(), 0); + assertEquals(keyGroup2.getEndIndex(), 2047); + assertEquals(keyGroup2.size(), 2048); + + keyGroup = KeyGroupAssignment.getKeyGroup(4096, 3, 0); + assertEquals(keyGroup.getStartIndex(), 0); + assertEquals(keyGroup.getEndIndex(), 1365); + assertEquals(keyGroup.size(), 1366); + + keyGroup2 = KeyGroupAssignment.getKeyGroup(4096, 3, 1); + assertEquals(keyGroup2.getStartIndex(), 1366); + assertEquals(keyGroup2.getEndIndex(), 2730); + assertEquals(keyGroup2.size(), 1365); + + KeyGroup 
keyGroup3 = KeyGroupAssignment.getKeyGroup(4096, 3, 2); + assertEquals(keyGroup3.getStartIndex(), 2731); + assertEquals(keyGroup3.getEndIndex(), 4095); + assertEquals(keyGroup3.size(), 1365); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/desc/ListStateDescriptorTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/desc/ListStateDescriptorTest.java new file mode 100644 index 00000000..4475109b --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/desc/ListStateDescriptorTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.keystate.desc; + +import io.ray.streaming.state.config.ConfigKey; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ListStateDescriptorTest { + + @Test + public void test() { + ListStateDescriptor descriptor = + ListStateDescriptor.build("lsdTest", Integer.class, true); + descriptor.setTableName("table"); + Assert.assertTrue(descriptor.isOperatorList()); + + descriptor.setPartitionNumber(3); + descriptor.setIndex(0); + + Assert.assertEquals(descriptor.getPartitionNumber(), 3); + Assert.assertEquals(descriptor.getIndex(), 0); + + Assert.assertEquals( + descriptor.getIdentify(), + "lsdTest" + ConfigKey.DELIMITER + "3" + ConfigKey.DELIMITER + "0"); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/desc/MapStateDescriptorTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/desc/MapStateDescriptorTest.java new file mode 100644 index 00000000..18a1844a --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/desc/MapStateDescriptorTest.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.keystate.desc; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class MapStateDescriptorTest { + + @Test + public void test() { + MapStateDescriptor descriptor = + MapStateDescriptor.build("msdTest", String.class, Integer.class); + + descriptor.setTableName("table"); + Assert.assertEquals(descriptor.getTableName(), "table"); + Assert.assertEquals(descriptor.getName(), "msdTest"); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/desc/ValueStateDescriptorTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/desc/ValueStateDescriptorTest.java new file mode 100644 index 00000000..626e897f --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/desc/ValueStateDescriptorTest.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.keystate.desc; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ValueStateDescriptorTest { + + @Test + public void test() { + ValueStateDescriptor descriptor = + ValueStateDescriptor.build("vsdTest", Integer.class, 0); + Assert.assertEquals(descriptor.getDefaultValue().intValue(), 0); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/ListStateImplTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/ListStateImplTest.java new file mode 100644 index 00000000..0323ca7b --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/ListStateImplTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.keystate.state.impl; + +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.backend.impl.MemoryStateBackend; +import io.ray.streaming.state.keystate.KeyGroup; +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class ListStateImplTest { + + ListStateImpl listState; + KeyStateBackend keyStateBackend; + + @BeforeClass + public void setUp() throws Exception { + keyStateBackend = + new KeyStateBackend(1, new KeyGroup(1, 2), new MemoryStateBackend(new HashMap<>())); + ListStateDescriptor descriptor = + ListStateDescriptor.build("ListStateImplTest", Integer.class); + descriptor.setTableName("table"); + + listState = (ListStateImpl) keyStateBackend.getListState(descriptor); + } + + @Test + public void testAddGet() throws Exception { + keyStateBackend.setContext(1L, 1); + List list = listState.get(); + Assert.assertEquals(list.size(), 0); + + listState.add(1); + listState.add(2); + + Assert.assertEquals(listState.get(), Arrays.asList(1, 2)); + + listState.add(3); + Assert.assertEquals(listState.get(), Arrays.asList(1, 2, 3)); + + list = listState.get(); + list.set(1, -1); + listState.add(4); + Assert.assertEquals(listState.get(), Arrays.asList(1, -1, 3, 4)); + + keyStateBackend.setCurrentKey(2); + + listState.add(5); + listState.add(6); + + Assert.assertEquals(listState.get(), Arrays.asList(5, 6)); + } + + @Test(dependsOnMethods = {"testAddGet"}) + public void testUpdate() throws Exception { + Assert.assertEquals(listState.get(), Arrays.asList(5, 6)); + + listState.update(Arrays.asList(7, 8, 9)); + + List list = listState.get(); + Assert.assertEquals(list, Arrays.asList(7, 8, 9)); + + list.set(1, 10); + listState.update(list); + Assert.assertEquals(list, Arrays.asList(7, 10, 9)); + + 
listState.update(null); + Assert.assertEquals(listState.get().size(), 0); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/MapStateImplTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/MapStateImplTest.java new file mode 100644 index 00000000..77d6e00b --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/MapStateImplTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.keystate.state.impl; + +import com.google.common.collect.ImmutableMap; +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.backend.impl.MemoryStateBackend; +import io.ray.streaming.state.keystate.KeyGroup; +import io.ray.streaming.state.keystate.desc.MapStateDescriptor; +import java.util.HashMap; +import java.util.Map; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class MapStateImplTest { + + MapStateImpl mapState; + KeyStateBackend keyStateBackend; + + @BeforeClass + public void setUp() throws Exception { + keyStateBackend = + new KeyStateBackend(1, new KeyGroup(1, 2), new MemoryStateBackend(new HashMap<>())); + MapStateDescriptor descriptor = + MapStateDescriptor.build("MapStateImplTest", Integer.class, String.class); + descriptor.setTableName("table"); + mapState = (MapStateImpl) keyStateBackend.getMapState(descriptor); + } + + @Test + public void testPuTGet() throws Exception { + keyStateBackend.setContext(1L, 1); + + Assert.assertEquals(mapState.get().size(), 0); + + mapState.put(1, "1"); + mapState.put(2, "2"); + + Assert.assertTrue(mapState.contains(1)); + Assert.assertTrue(mapState.contains(2)); + Assert.assertFalse(mapState.contains(3)); + + Assert.assertEquals("1", mapState.get(1)); + Assert.assertEquals("2", mapState.get(2)); + + mapState.remove(1); + Assert.assertFalse(mapState.contains(1)); + Assert.assertTrue(mapState.contains(2)); + + mapState.remove(2); + mapState.putAll(ImmutableMap.of(1, "1", 2, "2")); + Assert.assertEquals("1", mapState.get(1)); + Assert.assertEquals("2", mapState.get(2)); + } + + @Test(dependsOnMethods = {"testPuTGet"}) + public void testUpdate() throws Exception { + Assert.assertEquals(mapState.get().size(), 2); + + mapState.update(ImmutableMap.of(3, "3", 4, "4")); + Assert.assertEquals(mapState.get(3), "3"); + Assert.assertEquals(mapState.get(4), "4"); + 
Assert.assertEquals(mapState.get().size(), 2); + + Map map = ImmutableMap.of(5, "5", 6, "6"); + mapState.update(map); + Assert.assertEquals(mapState.get(), map); + + mapState.update(null); + Assert.assertEquals(mapState.get().size(), 0); + } + + @Test + public void testFailover() throws Exception { + keyStateBackend.setContext(1L, 1); + + Assert.assertEquals(mapState.get().size(), 0); + + mapState.put(1, "1"); + mapState.put(2, "2"); + + Assert.assertTrue(mapState.contains(1)); + Assert.assertTrue(mapState.contains(2)); + Assert.assertFalse(mapState.contains(3)); + + Assert.assertEquals("1", mapState.get(1)); + Assert.assertEquals("2", mapState.get(2)); + + mapState.remove(1); + Assert.assertFalse(mapState.contains(1)); + Assert.assertTrue(mapState.contains(2)); + + mapState.remove(2); + mapState.putAll(ImmutableMap.of(1, "1", 2, "2")); + Assert.assertEquals("1", mapState.get(1)); + Assert.assertEquals("2", mapState.get(2)); + + keyStateBackend.finish(5); + + Assert.assertEquals("1", mapState.get(1)); + Assert.assertEquals("2", mapState.get(2)); + + mapState.put(2, "3"); + Assert.assertEquals("3", mapState.get(2)); + + keyStateBackend.finish(6); + + keyStateBackend.commit(5); + Assert.assertEquals("3", mapState.get(2)); + + keyStateBackend.commit(6); + keyStateBackend.ackCommit(5, 0); + keyStateBackend.ackCommit(6, 1); + + mapState.put(2, "5"); + Assert.assertEquals("5", mapState.get(2)); + mapState.update(null); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/OperatorStateImplTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/OperatorStateImplTest.java new file mode 100644 index 00000000..4639abf0 --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/OperatorStateImplTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state.impl; + +import io.ray.streaming.state.backend.OperatorStateBackend; +import io.ray.streaming.state.backend.impl.MemoryStateBackend; +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class OperatorStateImplTest { + + OperatorStateImpl operatorState; + ListStateDescriptor descriptor; + OperatorStateBackend operatorStateBackend; + + @BeforeClass + public void setUp() { + String table_name = "operatorState"; + Map config = new HashMap<>(); + + operatorStateBackend = new OperatorStateBackend(new MemoryStateBackend(config)); + + descriptor = + ListStateDescriptor.build( + "OperatorStateImplTest" + System.currentTimeMillis(), Integer.class, true); + descriptor.setPartitionNumber(1); + descriptor.setIndex(0); + descriptor.setTableName(table_name); + + operatorState = (OperatorStateImpl) operatorStateBackend.getSplitListState(descriptor); + } + + @Test + public void testInit() throws Exception { + operatorStateBackend.setCheckpointId(1L); + List list = operatorState.get(); + Assert.assertEquals(list.size(), 0); + + for (int i = 
0; i < 100; i++) { + operatorState.add(i); + } + Assert.assertEquals(operatorState.get().size(), 100); + operatorStateBackend.finish(1L); + operatorStateBackend.commit(1L); + operatorStateBackend.ackCommit(1L, 0); + + operatorStateBackend.finish(5L); + operatorStateBackend.commit(5L); + operatorStateBackend.ackCommit(5L, 0); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/ValueStateImplTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/ValueStateImplTest.java new file mode 100644 index 00000000..efc74be0 --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/impl/ValueStateImplTest.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.keystate.state.impl; + +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.backend.impl.MemoryStateBackend; +import io.ray.streaming.state.keystate.KeyGroup; +import io.ray.streaming.state.keystate.desc.ValueStateDescriptor; +import java.util.HashMap; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class ValueStateImplTest { + + ValueStateImpl valueState; + KeyStateBackend keyStateBackend; + + @BeforeClass + public void setUp() throws Exception { + keyStateBackend = + new KeyStateBackend(1, new KeyGroup(1, 2), new MemoryStateBackend(new HashMap<>())); + ValueStateDescriptor descriptor = + ValueStateDescriptor.build("ValueStateImplTest", String.class, "hello"); + descriptor.setTableName("table"); + + valueState = (ValueStateImpl) keyStateBackend.getValueState(descriptor); + } + + @Test + public void testUpdateGet() throws Exception { + keyStateBackend.setContext(1L, 1); + + Assert.assertEquals(valueState.get(), "hello"); + + String str = valueState.get(); + + valueState.update(str + " world"); + Assert.assertEquals(valueState.get(), "hello world"); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/ListStateStoreManagerTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/ListStateStoreManagerTest.java new file mode 100644 index 00000000..1c9c5fe9 --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/ListStateStoreManagerTest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state.proxy; + +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import java.util.Arrays; +import java.util.List; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class ListStateStoreManagerTest extends StateStoreManagerTest { + + ListStateStoreManagerProxy proxy; + + @BeforeClass + public void setUp() { + ListStateDescriptor descriptor = ListStateDescriptor.build("list", Integer.class); + descriptor.setTableName("tableName"); + keyStateBackend.setContext(1L, "key"); + proxy = new ListStateStoreManagerProxy<>(keyStateBackend, descriptor); + } + + @Test + public void test() throws Exception { + Assert.assertEquals(proxy.getListState().get().size(), 0); + + List list = Arrays.asList(1, 2, 3); + proxy.put("key1", list); + Assert.assertEquals(proxy.get("key1"), list); + + proxy.put("key1", Arrays.asList(1, 3)); + + proxy.put("key2", Arrays.asList(4, 5)); + Assert.assertEquals(proxy.get("key2"), Arrays.asList(4, 5)); + + Assert.assertEquals(proxy.get("key1"), Arrays.asList(1, 3)); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/MapStateStoreManagerTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/MapStateStoreManagerTest.java new file mode 100644 index 
00000000..fd99a248 --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/MapStateStoreManagerTest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.keystate.state.proxy; + +import io.ray.streaming.state.keystate.desc.MapStateDescriptor; +import java.util.HashMap; +import java.util.Map; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class MapStateStoreManagerTest extends StateStoreManagerTest { + + MapStateStoreManagerProxy facade; + + @BeforeClass + public void setUp() { + MapStateDescriptor descriptor = + MapStateDescriptor.build("map", String.class, Integer.class); + descriptor.setTableName("tableName"); + keyStateBackend.setContext(1L, "key"); + facade = new MapStateStoreManagerProxy<>(keyStateBackend, descriptor); + } + + @Test + public void test() throws Exception { + Assert.assertEquals(facade.getMapState().get().size(), 0); + + Map map = new HashMap<>(); + map.put("key1", 1); + map.put("key2", 2); + facade.put("key1", map); + Assert.assertEquals(facade.get("key1"), map); + + map.remove("key1"); + facade.put("key2", map); + 
Assert.assertEquals(facade.get("key2"), map); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/StateStoreManagerTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/StateStoreManagerTest.java new file mode 100644 index 00000000..b919b5a0 --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/StateStoreManagerTest.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.keystate.state.proxy; + +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.backend.impl.MemoryStateBackend; +import io.ray.streaming.state.keystate.KeyGroup; +import java.util.HashMap; + +public class StateStoreManagerTest { + + protected KeyStateBackend keyStateBackend = + new KeyStateBackend( + 1, new KeyGroup(1, 1), new MemoryStateBackend(new HashMap())); +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/ValueStateStoreManagerTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/ValueStateStoreManagerTest.java new file mode 100644 index 00000000..7802b05f --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/keystate/state/proxy/ValueStateStoreManagerTest.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.keystate.state.proxy; + +import io.ray.streaming.state.keystate.desc.ValueStateDescriptor; +import io.ray.streaming.state.keystate.state.ValueState; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class ValueStateStoreManagerTest extends StateStoreManagerTest { + + ValueStateStoreManagerProxy proxy; + + @BeforeClass + public void setUp() { + ValueStateDescriptor descriptor = + ValueStateDescriptor.build("value", Integer.class, 0); + descriptor.setTableName("tableName"); + keyStateBackend.setContext(1L, "key"); + proxy = new ValueStateStoreManagerProxy<>(keyStateBackend, descriptor); + } + + @Test + public void test() throws Exception { + ValueState state = proxy.getValueState(); + Assert.assertEquals(state.get().intValue(), 0); + + proxy.put("key1", 2); + Assert.assertEquals(proxy.get("key1").intValue(), 2); + + proxy.put("key1", 3); + Assert.assertEquals(proxy.get("key1").intValue(), 3); + + proxy.put("key2", 9); + Assert.assertEquals(proxy.get("key2").intValue(), 9); + + proxy.put("key2", 6); + Assert.assertEquals(proxy.get("key2").intValue(), 6); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/strategy/DualStateStrategyTest.java b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/strategy/DualStateStrategyTest.java new file mode 100644 index 00000000..44663287 --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/strategy/DualStateStrategyTest.java @@ -0,0 +1,433 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.ray.streaming.state.strategy; + +import com.google.common.collect.Lists; +import io.ray.streaming.state.backend.BackendType; +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.backend.StateBackendBuilder; +import io.ray.streaming.state.backend.StateStrategy; +import io.ray.streaming.state.config.ConfigKey; +import io.ray.streaming.state.keystate.KeyGroup; +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import io.ray.streaming.state.keystate.desc.MapStateDescriptor; +import io.ray.streaming.state.keystate.desc.ValueStateDescriptor; +import io.ray.streaming.state.keystate.state.ListState; +import io.ray.streaming.state.keystate.state.MapState; +import io.ray.streaming.state.keystate.state.ValueState; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class DualStateStrategyTest { + + private final String table = "kepler_cp_store"; + private final String defaultValue = "default"; + protected KeyStateBackend keyStateBackend; + Map config = new HashMap<>(); + private String currentTime; + + @BeforeClass + public void setUp() { + config.put(ConfigKey.STATE_STRATEGY_MODE, StateStrategy.DUAL_VERSION.name()); + currentTime = Long.toString(System.currentTimeMillis()); + } + + public void caseKV() { + ValueStateDescriptor valueStateDescriptor = + ValueStateDescriptor.build("VALUE-" + currentTime, String.class, defaultValue); + 
valueStateDescriptor.setTableName(table); + ValueState state = this.keyStateBackend.getValueState(valueStateDescriptor); + + this.keyStateBackend.setCheckpointId(1l); + + state.setCurrentKey("1"); + state.update("hello"); + state.setCurrentKey("2"); + state.update("world"); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), "hello"); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), "world"); + + this.keyStateBackend.finish(1); + + this.keyStateBackend.setCheckpointId(2); + state.setCurrentKey(("3")); + state.update("eagle"); + state.setCurrentKey(("4")); + state.update("alex"); + + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), "eagle"); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), "alex"); + + this.keyStateBackend.commit(1); + this.keyStateBackend.ackCommit(1, 1); + + this.keyStateBackend.finish(2); + this.keyStateBackend.setCheckpointId(3); + + state.setCurrentKey(("1")); + state.update("tim"); + state.setCurrentKey(("4")); + state.update("scala"); + + this.keyStateBackend.finish(3); + this.keyStateBackend.setCheckpointId(4); + + state.setCurrentKey(("3")); + state.update("cook"); + state.setCurrentKey(("2")); + state.update("inf"); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), "tim"); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), "inf"); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), "cook"); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), "scala"); + + this.keyStateBackend.commit(2); + this.keyStateBackend.ackCommit(2, 2); + + // do rollback, all memory data is deleted. 
+ this.keyStateBackend.rollBack(1); + this.keyStateBackend.setCheckpointId(1); + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), defaultValue); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), defaultValue); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), defaultValue); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), defaultValue); + + this.keyStateBackend.setCheckpointId(4); + this.keyStateBackend.setCurrentKey("1"); + state.update("tim"); + this.keyStateBackend.finish(4); + + this.keyStateBackend.setCheckpointId(5); + this.keyStateBackend.setCurrentKey("2"); + state.update("info"); + this.keyStateBackend.finish(5); + + this.keyStateBackend.setCheckpointId(6); + state.update("cook"); + this.keyStateBackend.finish(6); + + this.keyStateBackend.setCheckpointId(7); + this.keyStateBackend.setCurrentKey("1"); + Assert.assertEquals(state.get(), "tim"); + + this.keyStateBackend.setCurrentKey("2"); + Assert.assertEquals(state.get(), "cook"); + + this.keyStateBackend.commit(5); + this.keyStateBackend.ackCommit(5, 5); + + this.keyStateBackend.setCurrentKey("1"); + Assert.assertEquals(state.get(), "tim"); + + this.keyStateBackend.rollBack(6); + } + + public void caseKVGap() { + ValueStateDescriptor valueStateDescriptor = + ValueStateDescriptor.build("value2-" + currentTime, String.class, defaultValue); + valueStateDescriptor.setTableName(table); + ValueState state = this.keyStateBackend.getValueState(valueStateDescriptor); + + this.keyStateBackend.setCheckpointId(5L); + + state.setCurrentKey("1"); + state.update("hello"); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), "hello"); + + this.keyStateBackend.setCheckpointId(5); + this.keyStateBackend.setCurrentKey("1"); + state.update("info"); + this.keyStateBackend.finish(5); + this.keyStateBackend.commit(5); + this.keyStateBackend.ackCommit(5, 5); + + this.keyStateBackend.setCheckpointId(10); + Assert.assertEquals(state.get(), "info"); + 
this.keyStateBackend.finish(10); + this.keyStateBackend.commit(10); + this.keyStateBackend.ackCommit(10, 10); + + this.keyStateBackend.setCheckpointId(15); + state.update("world"); + this.keyStateBackend.finish(15); + this.keyStateBackend.commit(15); + this.keyStateBackend.ackCommit(15, 15); + + this.keyStateBackend.setCheckpointId(11); + this.keyStateBackend.rollBack(11); + Assert.assertEquals(state.get(), "info"); + + this.keyStateBackend.setCheckpointId(15); + state.update("world2"); + this.keyStateBackend.finish(15); + this.keyStateBackend.commit(15); + this.keyStateBackend.ackCommit(15, 15); + + this.keyStateBackend.setCheckpointId(11); + this.keyStateBackend.rollBack(11); + Assert.assertEquals(state.get(), "info"); + } + + public void caseKList() { + ListStateDescriptor listStateDescriptor = + ListStateDescriptor.build("LIST-" + currentTime, Integer.class); + listStateDescriptor.setTableName(table); + ListState state = this.keyStateBackend.getListState(listStateDescriptor); + + this.keyStateBackend.setCheckpointId(1l); + + state.setCurrentKey("1"); + state.add(1); + state.setCurrentKey("2"); + state.add(2); + + state.setCurrentKey("1"); + List result = state.get(); + Assert.assertEquals(result, Arrays.asList(1)); + state.setCurrentKey("2"); + Assert.assertEquals(state.get(), Arrays.asList(2)); + + this.keyStateBackend.finish(1); + + this.keyStateBackend.setCheckpointId(2); + state.setCurrentKey(("3")); + state.add(3); + state.setCurrentKey(("4")); + state.add(4); + + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), Arrays.asList(3)); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), Arrays.asList(4)); + + this.keyStateBackend.commit(1); + this.keyStateBackend.ackCommit(1, 1); + + this.keyStateBackend.finish(2); + this.keyStateBackend.setCheckpointId(3); + + state.setCurrentKey(("1")); + state.add(2); + state.setCurrentKey(("4")); + state.add(5); + + this.keyStateBackend.finish(3); + this.keyStateBackend.setCheckpointId(4); + + 
state.setCurrentKey(("3")); + state.add(4); + state.setCurrentKey(("2")); + state.add(3); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), Arrays.asList(1, 2)); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), Arrays.asList(2, 3)); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), Arrays.asList(3, 4)); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), Arrays.asList(4, 5)); + + this.keyStateBackend.commit(2); + this.keyStateBackend.ackCommit(2, 2); + + // do rollback, all memory data is deleted. + this.keyStateBackend.rollBack(1); + this.keyStateBackend.setCheckpointId(1); + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), Lists.newArrayList()); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), Lists.newArrayList()); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), Lists.newArrayList()); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), Lists.newArrayList()); + + this.keyStateBackend.setCheckpointId(4); + this.keyStateBackend.setCurrentKey("1"); + state.add(1); + this.keyStateBackend.finish(4); + + this.keyStateBackend.setCheckpointId(5); + this.keyStateBackend.setCurrentKey("2"); + state.add(2); + this.keyStateBackend.finish(5); + + this.keyStateBackend.setCheckpointId(6); + state.add(3); + this.keyStateBackend.finish(6); + + this.keyStateBackend.setCheckpointId(7); + this.keyStateBackend.setCurrentKey("2"); + Assert.assertEquals(state.get(), Arrays.asList(2, 2, 3)); + + this.keyStateBackend.commit(5); + this.keyStateBackend.ackCommit(5, 5); + + this.keyStateBackend.rollBack(5); + + this.keyStateBackend.setCurrentKey("1"); + Assert.assertEquals(state.get(), Arrays.asList(1)); + } + + public void caseKMap() { + MapStateDescriptor mapStateDescriptor = + MapStateDescriptor.build("MAP-" + currentTime, Integer.class, Integer.class); + mapStateDescriptor.setTableName(table); + MapState state = this.keyStateBackend.getMapState(mapStateDescriptor); + 
+ this.keyStateBackend.setCheckpointId(1l); + + state.setCurrentKey("1"); + state.put(1, 1); + state.setCurrentKey("2"); + state.put(2, 2); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(1), Integer.valueOf(1)); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(2), Integer.valueOf(2)); + + this.keyStateBackend.finish(1); + + this.keyStateBackend.setCheckpointId(2); + state.setCurrentKey(("3")); + state.put(3, 3); + state.setCurrentKey(("4")); + state.put(4, 4); + + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(3), Integer.valueOf(3)); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(4), Integer.valueOf(4)); + + this.keyStateBackend.commit(1); + this.keyStateBackend.ackCommit(1, 1); + + this.keyStateBackend.finish(2); + this.keyStateBackend.setCheckpointId(3); + + state.setCurrentKey(("1")); + state.put(5, 5); + state.setCurrentKey(("4")); + state.put(6, 6); + + this.keyStateBackend.finish(3); + this.keyStateBackend.setCheckpointId(4); + + state.setCurrentKey(("3")); + state.put(7, 7); + state.setCurrentKey(("2")); + state.put(8, 8); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(1), Integer.valueOf(1)); + Assert.assertEquals(state.get(5), Integer.valueOf(5)); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(2), Integer.valueOf(2)); + Assert.assertEquals(state.get(8), Integer.valueOf(8)); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(3), Integer.valueOf(3)); + Assert.assertEquals(state.get(7), Integer.valueOf(7)); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(4), Integer.valueOf(4)); + Assert.assertEquals(state.get(6), Integer.valueOf(6)); + + this.keyStateBackend.commit(2); + this.keyStateBackend.ackCommit(2, 2); + + // do rollback, memory data is deleted. 
+ this.keyStateBackend.rollBack(1); + this.keyStateBackend.setCheckpointId(1); + state.setCurrentKey(("1")); + Assert.assertEquals(state.entries(), (new HashMap()).entrySet()); + state.setCurrentKey(("2")); + Assert.assertEquals(state.entries(), (new HashMap()).entrySet()); + state.setCurrentKey(("3")); + Assert.assertEquals(state.entries(), (new HashMap()).entrySet()); + state.setCurrentKey(("4")); + Assert.assertEquals(state.entries(), (new HashMap()).entrySet()); + + this.keyStateBackend.setCheckpointId(4); + this.keyStateBackend.setCurrentKey("1"); + state.put(1, 1); + this.keyStateBackend.finish(4); + + this.keyStateBackend.setCheckpointId(5); + this.keyStateBackend.setCurrentKey("2"); + state.put(2, 2); + this.keyStateBackend.finish(5); + + this.keyStateBackend.setCheckpointId(6); + state.put(3, 3); + this.keyStateBackend.finish(6); + + this.keyStateBackend.setCheckpointId(7); + this.keyStateBackend.setCurrentKey("1"); + Assert.assertEquals(state.get(1), Integer.valueOf(1)); + + this.keyStateBackend.setCurrentKey("2"); + Assert.assertEquals(state.get(2), Integer.valueOf(2)); + Assert.assertEquals(state.get(3), Integer.valueOf(3)); + + this.keyStateBackend.commit(5); + this.keyStateBackend.ackCommit(5, 5); + + this.keyStateBackend.rollBack(5); + + this.keyStateBackend.setCurrentKey("1"); + Assert.assertEquals(state.get(1), Integer.valueOf(1)); + + this.keyStateBackend.setCurrentKey("2"); + Assert.assertEquals(state.get(2), Integer.valueOf(2)); + Assert.assertEquals(state.get(3), null); + } + + @Test + public void testMem() { + config.put(ConfigKey.STATE_BACKEND_TYPE, BackendType.MEMORY.name()); + this.keyStateBackend = + new KeyStateBackend(10, new KeyGroup(1, 3), StateBackendBuilder.buildStateBackend(config)); + caseKV(); + caseKVGap(); + caseKList(); + caseKMap(); + } +} diff --git a/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/strategy/MVStateStrategyTest.java 
b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/strategy/MVStateStrategyTest.java new file mode 100644 index 00000000..79c15666 --- /dev/null +++ b/streaming/java/streaming-state/src/test/java/io/ray/streaming/state/strategy/MVStateStrategyTest.java @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.ray.streaming.state.strategy; + +import io.ray.streaming.state.backend.BackendType; +import io.ray.streaming.state.backend.KeyStateBackend; +import io.ray.streaming.state.backend.StateBackendBuilder; +import io.ray.streaming.state.backend.StateStrategy; +import io.ray.streaming.state.config.ConfigKey; +import io.ray.streaming.state.keystate.KeyGroup; +import io.ray.streaming.state.keystate.desc.ListStateDescriptor; +import io.ray.streaming.state.keystate.desc.MapStateDescriptor; +import io.ray.streaming.state.keystate.desc.ValueStateDescriptor; +import io.ray.streaming.state.keystate.state.ListState; +import io.ray.streaming.state.keystate.state.MapState; +import io.ray.streaming.state.keystate.state.ValueState; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class MVStateStrategyTest { + + private final String table = "kepler_hlg_ut"; + Map config = new HashMap<>(); + private KeyStateBackend keyStateBackend; + private String currentTime; + + @BeforeClass + public void setUp() { + config.put(ConfigKey.STATE_STRATEGY_MODE, StateStrategy.SINGLE_VERSION.name()); + currentTime = Long.toString(System.currentTimeMillis()); + } + + public void caseKV() { + + ValueStateDescriptor valueStateDescriptor = + ValueStateDescriptor.build("mvint-" + currentTime, String.class, ""); + valueStateDescriptor.setTableName(table); + ValueState state = this.keyStateBackend.getValueState(valueStateDescriptor); + + this.keyStateBackend.setCheckpointId(1l); + + state.setCurrentKey("1"); + state.update("hello"); + state.setCurrentKey("2"); + state.update("world"); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), "hello"); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), "world"); + + this.keyStateBackend.finish(1); + + this.keyStateBackend.setCheckpointId(2); + state.setCurrentKey(("3")); + 
state.update("eagle"); + state.setCurrentKey(("4")); + state.update("alex"); + + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), "eagle"); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), "alex"); + + this.keyStateBackend.commit(1); + this.keyStateBackend.ackCommit(1, 1); + + this.keyStateBackend.finish(2); + this.keyStateBackend.setCheckpointId(3); + + state.setCurrentKey(("1")); + state.update("tim"); + state.setCurrentKey(("4")); + state.update("scala"); + + this.keyStateBackend.finish(3); + this.keyStateBackend.setCheckpointId(4); + + state.setCurrentKey(("3")); + state.update("cook"); + state.setCurrentKey(("2")); + state.update("inf"); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), "tim"); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), "inf"); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), "cook"); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), "scala"); + + this.keyStateBackend.commit(2); + this.keyStateBackend.ackCommit(2, 2); + + this.keyStateBackend.rollBack(2); + this.keyStateBackend.setCheckpointId(3); + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), "hello"); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), "world"); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), "eagle"); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), "alex"); + + state.setCurrentKey(("1")); + state.update("tim"); + this.keyStateBackend.finish(3); + this.keyStateBackend.setCheckpointId(4l); + + state.setCurrentKey(("4")); + state.update("scala"); + this.keyStateBackend.finish(4); + this.keyStateBackend.setCheckpointId(5l); + + state.setCurrentKey(("3")); + state.update("cook"); + this.keyStateBackend.finish(5); + + state.setCurrentKey(("2")); + state.update("info"); + this.keyStateBackend.finish(6); + this.keyStateBackend.setCheckpointId(6); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), "tim"); 
+ state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), "info"); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), "cook"); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), "scala"); + + this.keyStateBackend.commit(5); + this.keyStateBackend.ackCommit(5, 5); + + this.keyStateBackend.rollBack(5); + + this.keyStateBackend.setCheckpointId(6); + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), "tim"); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), "world"); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), "cook"); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), "scala"); + } + + public void caseKList() { + ListStateDescriptor listStateDescriptor = + ListStateDescriptor.build("mvlist-" + currentTime, Integer.class); + listStateDescriptor.setTableName(table); + ListState state = this.keyStateBackend.getListState(listStateDescriptor); + + this.keyStateBackend.setCheckpointId(1l); + + state.setCurrentKey("1"); + state.add(1); + state.setCurrentKey("2"); + state.add(2); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), Arrays.asList(1)); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), Arrays.asList(2)); + + this.keyStateBackend.finish(1); + + this.keyStateBackend.setCheckpointId(2); + state.setCurrentKey(("3")); + state.add(3); + state.setCurrentKey(("4")); + state.add(4); + + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), Arrays.asList(3)); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), Arrays.asList(4)); + + this.keyStateBackend.commit(1); + this.keyStateBackend.ackCommit(1, 1); + + this.keyStateBackend.finish(2); + this.keyStateBackend.setCheckpointId(3); + + state.setCurrentKey(("1")); + state.add(2); + state.setCurrentKey(("4")); + state.add(5); + + this.keyStateBackend.finish(3); + this.keyStateBackend.setCheckpointId(4); + + state.setCurrentKey(("3")); + state.add(4); + state.setCurrentKey(("2")); 
+ state.add(3); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), Arrays.asList(1, 2)); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), Arrays.asList(2, 3)); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), Arrays.asList(3, 4)); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), Arrays.asList(4, 5)); + + this.keyStateBackend.commit(2); + this.keyStateBackend.ackCommit(2, 2); + + this.keyStateBackend.rollBack(2); + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(), Arrays.asList(1)); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(), Arrays.asList(2)); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(), Arrays.asList(3)); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(), Arrays.asList(4)); + + this.keyStateBackend.setCheckpointId(4); + this.keyStateBackend.setCurrentKey("1"); + state.add(1); + this.keyStateBackend.finish(4); + + this.keyStateBackend.setCheckpointId(5); + this.keyStateBackend.setCurrentKey("2"); + state.add(2); + this.keyStateBackend.finish(5); + + this.keyStateBackend.setCheckpointId(6); + state.add(3); + this.keyStateBackend.finish(6); + + this.keyStateBackend.setCheckpointId(7); + this.keyStateBackend.setCurrentKey("1"); + Assert.assertEquals(state.get(), Arrays.asList(1, 1)); + + this.keyStateBackend.setCurrentKey("2"); + Assert.assertEquals(state.get(), Arrays.asList(2, 2, 3)); + + this.keyStateBackend.commit(5); + this.keyStateBackend.ackCommit(5, 5); + + this.keyStateBackend.rollBack(5); + + this.keyStateBackend.setCurrentKey("1"); + Assert.assertEquals(state.get(), Arrays.asList(1, 1)); + this.keyStateBackend.setCurrentKey("2"); + Assert.assertEquals(state.get(), Arrays.asList(2, 2)); + } + + public void caseKMap() { + MapStateDescriptor mapStateDescriptor = + MapStateDescriptor.build("mvmap-" + currentTime, Integer.class, Integer.class); + mapStateDescriptor.setTableName(table); + MapState state = 
this.keyStateBackend.getMapState(mapStateDescriptor); + + this.keyStateBackend.setCheckpointId(1l); + + state.setCurrentKey("1"); + state.put(1, 1); + state.setCurrentKey("2"); + state.put(2, 2); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(1), Integer.valueOf(1)); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(2), Integer.valueOf(2)); + + this.keyStateBackend.finish(1); + + this.keyStateBackend.setCheckpointId(2); + state.setCurrentKey(("3")); + state.put(3, 3); + state.setCurrentKey(("4")); + state.put(4, 4); + + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(3), Integer.valueOf(3)); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(4), Integer.valueOf(4)); + + this.keyStateBackend.commit(1); + this.keyStateBackend.ackCommit(1, 1); + + this.keyStateBackend.finish(2); + this.keyStateBackend.setCheckpointId(3); + + state.setCurrentKey(("1")); + state.put(5, 5); + state.setCurrentKey(("4")); + state.put(6, 6); + + this.keyStateBackend.finish(3); + this.keyStateBackend.setCheckpointId(4); + + state.setCurrentKey(("3")); + state.put(7, 7); + state.setCurrentKey(("2")); + state.put(8, 8); + + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(1), Integer.valueOf(1)); + Assert.assertEquals(state.get(5), Integer.valueOf(5)); + state.setCurrentKey(("2")); + Assert.assertEquals(state.get(2), Integer.valueOf(2)); + Assert.assertEquals(state.get(8), Integer.valueOf(8)); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(3), Integer.valueOf(3)); + Assert.assertEquals(state.get(7), Integer.valueOf(7)); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(4), Integer.valueOf(4)); + Assert.assertEquals(state.get(6), Integer.valueOf(6)); + + this.keyStateBackend.commit(2); + this.keyStateBackend.ackCommit(2, 2); + + this.keyStateBackend.rollBack(2); + state.setCurrentKey(("1")); + Assert.assertEquals(state.get(1), Integer.valueOf(1)); + Assert.assertNull(state.get(5)); + 
state.setCurrentKey(("2")); + Assert.assertEquals(state.get(2), Integer.valueOf(2)); + Assert.assertNull(state.get(8)); + state.setCurrentKey(("3")); + Assert.assertEquals(state.get(3), Integer.valueOf(3)); + Assert.assertNull(state.get(7)); + state.setCurrentKey(("4")); + Assert.assertEquals(state.get(4), Integer.valueOf(4)); + Assert.assertNull(state.get(6)); + + this.keyStateBackend.setCheckpointId(4); + this.keyStateBackend.setCurrentKey("1"); + state.put(5, 5); + this.keyStateBackend.finish(4); + + this.keyStateBackend.setCheckpointId(5); + this.keyStateBackend.setCurrentKey("2"); + state.put(8, 8); + this.keyStateBackend.finish(5); + + this.keyStateBackend.setCheckpointId(6); + state.put(7, 7); + this.keyStateBackend.finish(6); + + this.keyStateBackend.setCheckpointId(7); + this.keyStateBackend.setCurrentKey("1"); + Assert.assertEquals(state.get(1), Integer.valueOf(1)); + Assert.assertEquals(state.get(5), Integer.valueOf(5)); + + this.keyStateBackend.setCurrentKey("2"); + Assert.assertEquals(state.get(2), Integer.valueOf(2)); + Assert.assertEquals(state.get(8), Integer.valueOf(8)); + Assert.assertEquals(state.get(7), Integer.valueOf(7)); + + this.keyStateBackend.commit(5); + this.keyStateBackend.ackCommit(5, 5); + + this.keyStateBackend.rollBack(5); + + this.keyStateBackend.setCurrentKey("1"); + Assert.assertEquals(state.get(1), Integer.valueOf(1)); + Assert.assertEquals(state.get(5), Integer.valueOf(5)); + + this.keyStateBackend.setCurrentKey("2"); + Assert.assertEquals(state.get(2), Integer.valueOf(2)); + Assert.assertEquals(state.get(8), Integer.valueOf(8)); + Assert.assertNull(state.get(7)); + } + + @Test + public void testMemMV() { + config.put(ConfigKey.STATE_BACKEND_TYPE, BackendType.MEMORY.name()); + this.keyStateBackend = + new KeyStateBackend(10, new KeyGroup(1, 3), StateBackendBuilder.buildStateBackend(config)); + caseKV(); + caseKList(); + caseKMap(); + } +} diff --git a/streaming/java/test.sh b/streaming/java/test.sh new file mode 100755 index 
00000000..166457be --- /dev/null +++ b/streaming/java/test.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash + +# Cause the script to exit if a single command fails. +set -e +# Show explicitly which commands are currently running. +set -x + +ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd) + +pushd "$ROOT_DIR" + echo "Check java code format." + # check google java style + mvn -T16 spotless:check + # check naming and others + mvn -T16 checkstyle:check +popd + +echo "build ray streaming" +bazel build //streaming/java:all + +# Check that ray libstreaming_java doesn't include symbols from ray by accident. +# Otherwise the symbols may conflict. +symbols_conflict=$(nm bazel-bin/streaming/libstreaming_java.so | grep TaskFinisherInterface || true) +if [ -n "${symbols_conflict}" ]; then + echo "streaming should not include symbols from ray: ${symbols_conflict}" + exit 1 +fi + +echo "Running streaming tests." +java -cp "$ROOT_DIR"/../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar\ + org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml || +exit_code=$? +if [ -z ${exit_code+x} ]; then + exit_code=0 +fi +echo "Streaming TestNG results" +if [ -f "/tmp/ray_streaming_java_test_output/testng-results.xml" ] ; then + cat /tmp/ray_streaming_java_test_output/testng-results.xml +else + echo "Test result file doesn't exist" +fi + +# exit_code == 2 means there are skipped tests. +if [ $exit_code -ne 2 ] && [ $exit_code -ne 0 ] ; then + if [ -d "/tmp/ray_streaming_java_test_output/" ] ; then + echo "all test output" + for f in /tmp/ray_streaming_java_test_output/*.{log,xml}; do + if [ -f "$f" ]; then + echo "Cat file $f" + cat "$f" + elif [[ -d $f ]]; then + echo "$f is a directory" + fi + done + fi + for f in /home/travis/build/ray-project/ray/hs_err*log; do + if [ -f "$f" ]; then + echo "Cat file $f" + cat "$f" + fi + done + exit $exit_code +fi + +echo "Testing maven install." 
+cd "$ROOT_DIR"/../../java +echo "build ray maven deps" +bazel build gen_maven_deps +echo "maven install ray" +mvn -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN clean install -DskipTests -Dcheckstyle.skip +cd "$ROOT_DIR" +echo "maven install ray streaming" +mvn -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN clean install -DskipTests -Dcheckstyle.skip diff --git a/streaming/java/testng.xml b/streaming/java/testng.xml new file mode 100644 index 00000000..3eaadbbd --- /dev/null +++ b/streaming/java/testng.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/streaming/python/__init__.pxd b/streaming/python/__init__.pxd new file mode 100644 index 00000000..e69de29b diff --git a/streaming/python/__init__.py b/streaming/python/__init__.py new file mode 100644 index 00000000..2eb090c6 --- /dev/null +++ b/streaming/python/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa +# Ray should be imported before streaming +import ray +from ray.streaming.context import StreamingContext + +__all__ = ['StreamingContext'] diff --git a/streaming/python/_streaming.pyx b/streaming/python/_streaming.pyx new file mode 100644 index 00000000..3d845ff3 --- /dev/null +++ b/streaming/python/_streaming.pyx @@ -0,0 +1,6 @@ +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +include "includes/transfer.pxi" diff --git a/streaming/python/collector.py b/streaming/python/collector.py new file mode 100644 index 00000000..1760900f --- /dev/null +++ b/streaming/python/collector.py @@ -0,0 +1,80 @@ +import logging +import typing +from abc import ABC, abstractmethod + +from ray import Language +from ray.actor import ActorHandle +from ray.streaming import function +from ray.streaming import message +from ray.streaming import partition +from ray.streaming.runtime import serialization +from ray.streaming.runtime.transfer import ChannelID, DataWriter + +logger = logging.getLogger(__name__) + + +class Collector(ABC): + """ + The collector that collects 
data from an upstream operator, + and emits data to downstream operators. + """ + + @abstractmethod + def collect(self, record): + pass + + +class CollectionCollector(Collector): + def __init__(self, collector_list): + self._collector_list = collector_list + + def collect(self, value): + for collector in self._collector_list: + collector.collect(message.Record(value)) + + +class OutputCollector(Collector): + def __init__(self, writer: DataWriter, channel_ids: typing.List[str], + target_actors: typing.List[ActorHandle], + partition_func: partition.Partition): + self._writer = writer + self._channel_ids = [ChannelID(id_str) for id_str in channel_ids] + self._target_languages = [] + for actor in target_actors: + if actor._ray_actor_language == Language.PYTHON: + self._target_languages.append(function.Language.PYTHON) + elif actor._ray_actor_language == Language.JAVA: + self._target_languages.append(function.Language.JAVA) + else: + raise Exception("Unsupported language {}" + .format(actor._ray_actor_language)) + self._partition_func = partition_func + self.python_serializer = serialization.PythonSerializer() + self.cross_lang_serializer = serialization.CrossLangSerializer() + logger.info( + "Create OutputCollector, channel_ids {}, partition_func {}".format( + channel_ids, partition_func)) + + def collect(self, record): + partitions = self._partition_func \ + .partition(record, len(self._channel_ids)) + python_buffer = None + cross_lang_buffer = None + for partition_index in partitions: + if self._target_languages[partition_index] == \ + function.Language.PYTHON: + # avoid repeated serialization + if python_buffer is None: + python_buffer = self.python_serializer.serialize(record) + self._writer.write( + self._channel_ids[partition_index], + bytes([serialization.PYTHON_TYPE_ID]) + python_buffer) + else: + # avoid repeated serialization + if cross_lang_buffer is None: + cross_lang_buffer = self.cross_lang_serializer.serialize( + record) + self._writer.write( + 
self._channel_ids[partition_index], + bytes([serialization.CROSS_LANG_TYPE_ID]) + + cross_lang_buffer) diff --git a/streaming/python/config.py b/streaming/python/config.py new file mode 100644 index 00000000..b80d49b2 --- /dev/null +++ b/streaming/python/config.py @@ -0,0 +1,62 @@ +class Config: + STREAMING_JOB_NAME = "streaming.job.name" + STREAMING_OP_NAME = "streaming.op_name" + STREAMING_WORKER_NAME = "streaming.worker_name" + # channel + CHANNEL_TYPE = "channel_type" + MEMORY_CHANNEL = "memory_channel" + NATIVE_CHANNEL = "native_channel" + CHANNEL_SIZE = "channel_size" + CHANNEL_SIZE_DEFAULT = 10**8 + # return from StreamingReader.getBundle if only empty message read in this + # interval. + TIMER_INTERVAL_MS = "timer_interval_ms" + READ_TIMEOUT_MS = "read_timeout_ms" + DEFAULT_READ_TIMEOUT_MS = "10" + STREAMING_RING_BUFFER_CAPACITY = "streaming.ring_buffer_capacity" + # write an empty message if there is no data to be written in this + # interval. + STREAMING_EMPTY_MESSAGE_INTERVAL = "streaming.empty_message_interval" + + # operator type + OPERATOR_TYPE = "operator_type" + + # flow control + FLOW_CONTROL_TYPE = "streaming.flow_control_type" + WRITER_CONSUMED_STEP = "streaming.writer.consumed_step" + READER_CONSUMED_STEP = "streaming.reader.consumed_step" + + # state backend + CP_STATE_BACKEND_TYPE = "streaming.context-backend.type" + CP_STATE_BACKEND_MEMORY = "memory" + CP_STATE_BACKEND_LOCAL_FILE = "local_file" + CP_STATE_BACKEND_DEFAULT = CP_STATE_BACKEND_MEMORY + + # local disk + FILE_STATE_ROOT_PATH = "streaming.context-backend.file-state.root" + FILE_STATE_ROOT_PATH_DEFAULT = "/tmp/ray_streaming_state" + + # checkpoint + JOB_WORKER_CONTEXT_KEY = "jobworker_context_" + + # reliability level + REQUEST_ROLLBACK_RETRY_TIMES = 3 + + # checkpoint prefix key + JOB_WORKER_OP_CHECKPOINT_PREFIX_KEY = "jobwk_op_" + + +class ConfigHelper(object): + @staticmethod + def get_cp_local_file_root_dir(conf): + value = conf.get(Config.FILE_STATE_ROOT_PATH) + if value is not 
None: + return value + return Config.FILE_STATE_ROOT_PATH_DEFAULT + + @staticmethod + def get_cp_context_backend_type(conf): + value = conf.get(Config.CP_STATE_BACKEND_TYPE) + if value is not None: + return value + return Config.CP_STATE_BACKEND_DEFAULT diff --git a/streaming/python/context.py b/streaming/python/context.py new file mode 100644 index 00000000..074f8252 --- /dev/null +++ b/streaming/python/context.py @@ -0,0 +1,192 @@ +from abc import ABC, abstractmethod + +from ray.streaming.datastream import StreamSource +from ray.streaming.function import LocalFileSourceFunction +from ray.streaming.function import CollectionSourceFunction +from ray.streaming.function import SourceFunction +from ray.streaming.runtime.gateway_client import GatewayClient + + +class StreamingContext: + """ + Main entry point for ray streaming functionality. + A StreamingContext is also a wrapper of java + `io.ray.streaming.api.context.StreamingContext` + """ + + class Builder: + def __init__(self): + self._options = {} + + def option(self, key=None, value=None, conf=None): + """ + Sets a config option. Options set using this method are + automatically propagated to :class:`StreamingContext`'s own + configuration. + + Args: + key: a key name string for configuration property + value: a value string for configuration property + conf: multi key-value pairs as a dict + + Returns: + self + """ + if key is not None: + assert value is not None + self._options[key] = str(value) + if conf is not None: + for k, v in conf.items(): + self._options[k] = v + return self + + def build(self): + """ + Creates a StreamingContext based on the options set in this + builder. 
+ """ + ctx = StreamingContext() + ctx._gateway_client.with_config(self._options) + return ctx + + def __init__(self): + self.__gateway_client = GatewayClient() + self._j_ctx = self._gateway_client.create_streaming_context() + + def source(self, source_func: SourceFunction): + """Create an input data stream with a SourceFunction + + Args: + source_func: the SourceFunction used to create the data stream + + Returns: + The data stream constructed from the source_func + """ + return StreamSource.build_source(self, source_func) + + def from_values(self, *values): + """Creates a data stream from values + + Args: + values: The elements to create the data stream from. + + Returns: + The data stream representing the given values + """ + return self.from_collection(values) + + def from_collection(self, values): + """Creates a data stream from the given non-empty collection. + + Args: + values: The collection of elements to create the data stream from. + + Returns: + The data stream representing the given collection. + """ + assert values, "values shouldn't be None or empty" + func = CollectionSourceFunction(values) + return self.source(func) + + def read_text_file(self, filename: str): + """Reads the given file line-by-line and creates a data stream that + contains a string with the contents of each such line.""" + func = LocalFileSourceFunction(filename) + return self.source(func) + + def submit(self, job_name: str): + """Submit job for execution. + + Args: + job_name: name of the job + + Returns: + An JobSubmissionResult future + """ + self._gateway_client.execute(job_name) + # TODO return a JobSubmissionResult future + + def execute(self, job_name: str): + """Execute the job. This method will block until job finished. 
+ + Args: + job_name: name of the job + """ + # TODO support block to job finish + # job_submit_result = self.submit(job_name) + # job_submit_result.wait_finish() + raise Exception("Unsupported") + + @property + def _gateway_client(self): + return self.__gateway_client + + +class RuntimeContext(ABC): + @abstractmethod + def get_task_id(self): + """ + Returns: + Task id of the parallel task. + """ + pass + + @abstractmethod + def get_task_index(self): + """ + Gets the index of this parallel subtask. The index starts from 0 + and goes up to parallelism-1 (parallelism as returned by + `get_parallelism()`). + + Returns: + The index of the parallel subtask. + """ + pass + + @abstractmethod + def get_parallelism(self): + """ + Returns: + The parallelism with which the parallel task runs. + """ + pass + + @abstractmethod + def get_config(self): + """ + Returns: + The config with which the parallel task runs. + """ + pass + + @abstractmethod + def get_job_config(self): + """ + Returns: + The job config. + """ + pass + + +class RuntimeContextImpl(RuntimeContext): + def __init__(self, task_id, task_index, parallelism, **kargs): + self.task_id = task_id + self.task_index = task_index + self.parallelism = parallelism + self.config = kargs.get("config", {}) + self.job_config = kargs.get("job_config", {}) + + def get_task_id(self): + return self.task_id + + def get_task_index(self): + return self.task_index + + def get_parallelism(self): + return self.parallelism + + def get_config(self): + return self.config + + def get_job_config(self): + return self.job_config diff --git a/streaming/python/datastream.py b/streaming/python/datastream.py new file mode 100644 index 00000000..91193be2 --- /dev/null +++ b/streaming/python/datastream.py @@ -0,0 +1,564 @@ +from abc import ABC, abstractmethod + +from ray.streaming import function +from ray.streaming import partition + + +class Stream(ABC): + """ + Abstract base class of all stream types. 
A Stream represents a stream of + elements of the same type. A Stream can be transformed into another Stream + by applying a transformation. + """ + + def __init__(self, input_stream, j_stream, streaming_context=None): + self.input_stream = input_stream + self._j_stream = j_stream + if streaming_context is None: + assert input_stream is not None + self.streaming_context = input_stream.streaming_context + else: + self.streaming_context = streaming_context + + def get_streaming_context(self): + return self.streaming_context + + def get_parallelism(self): + """ + Returns: + the parallelism of this transformation + """ + return self._gateway_client(). \ + call_method(self._j_stream, "getParallelism") + + def set_parallelism(self, parallelism: int): + """Sets the parallelism of this transformation + + Args: + parallelism: The new parallelism to set on this transformation + + Returns: + self + """ + self._gateway_client(). \ + call_method(self._j_stream, "setParallelism", parallelism) + return self + + def get_input_stream(self): + """ + Returns: + input stream of this stream + """ + return self.input_stream + + def get_id(self): + """ + Returns: + An unique id identifies this stream. + """ + return self._gateway_client(). \ + call_method(self._j_stream, "getId") + + def with_config(self, key=None, value=None, conf=None): + """Set stream config. + + Args: + key: a key name string for configuration property + value: a value string for configuration property + conf: multi key-value pairs as a dict + + Returns: + self + """ + if key is not None: + assert type(key) is str + assert type(value) is str + self._gateway_client(). \ + call_method(self._j_stream, "withConfig", key, value) + if conf is not None: + for k, v in conf.items(): + assert type(k) is str + assert type(v) is str + self._gateway_client(). 
\ + call_method(self._j_stream, "withConfig", conf) + return self + + def get_config(self): + """ + Returns: + A dict config for this stream + """ + return self._gateway_client().call_method(self._j_stream, "getConfig") + + @abstractmethod + def get_language(self): + pass + + def forward(self): + """Set the partition function of this {@link Stream} so that output + elements are forwarded to next operator locally.""" + self._gateway_client().call_method(self._j_stream, "forward") + return self + + def disable_chain(self): + """Disable chain for this stream so that it will be run in a separate + task.""" + self._gateway_client().call_method(self._j_stream, "disableChain") + return self + + def _gateway_client(self): + return self.get_streaming_context()._gateway_client + + +class DataStream(Stream): + """ + Represents a stream of data which applies a transformation executed by + python. It's also a wrapper of java + `io.ray.streaming.python.stream.PythonDataStream` + """ + + def __init__(self, input_stream, j_stream, streaming_context=None): + super().__init__( + input_stream, j_stream, streaming_context=streaming_context) + + def get_language(self): + return function.Language.PYTHON + + def map(self, func): + """ + Applies a Map transformation on a :class:`DataStream`. + The transformation calls a :class:`ray.streaming.function.MapFunction` + for each element of the DataStream. + + Args: + func: The MapFunction that is called for each element of the + DataStream. If `func` is a python function instead of a subclass + of MapFunction, it will be wrapped as SimpleMapFunction. + + Returns: + A new data stream transformed by the MapFunction. + """ + if not isinstance(func, function.MapFunction): + func = function.SimpleMapFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). 
\ + call_method(self._j_stream, "map", j_func) + return DataStream(self, j_stream) + + def flat_map(self, func): + """ + Applies a FlatMap transformation on a :class:`DataStream`. The + transformation calls a :class:`ray.streaming.function.FlatMapFunction` + for each element of the DataStream. + Each FlatMapFunction call can return any number of elements including + none. + + Args: + func: The FlatMapFunction that is called for each element of the + DataStream. If `func` is a python function instead of a subclass + of FlatMapFunction, it will be wrapped as SimpleFlatMapFunction. + + Returns: + The transformed DataStream + """ + if not isinstance(func, function.FlatMapFunction): + func = function.SimpleFlatMapFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "flatMap", j_func) + return DataStream(self, j_stream) + + def filter(self, func): + """ + Applies a Filter transformation on a :class:`DataStream`. The + transformation calls a :class:`ray.streaming.function.FilterFunction` + for each element of the DataStream. + DataStream and retains only those element for which the function + returns True. + + Args: + func: The FilterFunction that is called for each element of the + DataStream. If `func` is a python function instead of a subclass of + FilterFunction, it will be wrapped as SimpleFilterFunction. + + Returns: + The filtered DataStream + """ + if not isinstance(func, function.FilterFunction): + func = function.SimpleFilterFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "filter", j_func) + return DataStream(self, j_stream) + + def union(self, *streams): + """Apply union transformations to this stream by merging data stream + outputs of the same type with each other. + + Args: + *streams: The DataStreams to union output with. 
+ + Returns: + A new UnionStream. + """ + assert len(streams) >= 1, "Need at least one stream to union with" + j_streams = [s._j_stream for s in streams] + j_stream = self._gateway_client().union(self._j_stream, *j_streams) + return UnionStream(self, j_stream) + + def key_by(self, func): + """ + Creates a new :class:`KeyDataStream` that uses the provided key to + partition data stream by key. + + Args: + func: The KeyFunction that is used for extracting the key for + partitioning. If `func` is a python function instead of a subclass + of KeyFunction, it will be wrapped as SimpleKeyFunction. + + Returns: + A KeyDataStream + """ + self._check_partition_call() + if not isinstance(func, function.KeyFunction): + func = function.SimpleKeyFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "keyBy", j_func) + return KeyDataStream(self, j_stream) + + def broadcast(self): + """ + Sets the partitioning of the :class:`DataStream` so that the output + elements are broadcast to every parallel instance of the next + operation. + + Returns: + The DataStream with broadcast partitioning set. + """ + self._check_partition_call() + self._gateway_client().call_method(self._j_stream, "broadcast") + return self + + def partition_by(self, partition_func): + """ + Sets the partitioning of the :class:`DataStream` so that the elements + of stream are partitioned by specified partition function. + + Args: + partition_func: partition function. + If `func` is a python function instead of a subclass of Partition, + it will be wrapped as SimplePartition. + + Returns: + The DataStream with specified partitioning set. + """ + self._check_partition_call() + if not isinstance(partition_func, partition.Partition): + partition_func = partition.SimplePartition(partition_func) + j_partition = self._gateway_client().create_py_func( + partition.serialize(partition_func)) + self._gateway_client(). 
\ + call_method(self._j_stream, "partitionBy", j_partition) + return self + + def _check_partition_call(self): + """ + If parent stream is a java stream, we can't call partition related + methods in the python stream + """ + if self.input_stream is not None and \ + self.input_stream.get_language() == function.Language.JAVA: + raise Exception("Partition related methods can't be called on a " + "python stream if parent stream is a java stream.") + + def sink(self, func): + """ + Create a StreamSink with the given sink. + + Args: + func: sink function. + + Returns: + a StreamSink. + """ + if not isinstance(func, function.SinkFunction): + func = function.SimpleSinkFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "sink", j_func) + return StreamSink(self, j_stream, func) + + def as_java_stream(self): + """ + Convert this stream as a java JavaDataStream. + The converted stream and this stream are the same logical stream, + which has same stream id. Changes in converted stream will be reflected + in this stream and vice versa. + """ + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "asJavaStream") + return JavaDataStream(self, j_stream) + + +class JavaDataStream(Stream): + """ + Represents a stream of data which applies a transformation executed by + java. 
It's also a wrapper of java + `io.ray.streaming.api.stream.DataStream` + """ + + def __init__(self, input_stream, j_stream, streaming_context=None): + super().__init__( + input_stream, j_stream, streaming_context=streaming_context) + + def get_language(self): + return function.Language.JAVA + + def map(self, java_func_class): + """See io.ray.streaming.api.stream.DataStream.map""" + return JavaDataStream(self, self._unary_call("map", java_func_class)) + + def flat_map(self, java_func_class): + """See io.ray.streaming.api.stream.DataStream.flatMap""" + return JavaDataStream(self, self._unary_call("flatMap", + java_func_class)) + + def filter(self, java_func_class): + """See io.ray.streaming.api.stream.DataStream.filter""" + return JavaDataStream(self, self._unary_call("filter", + java_func_class)) + + def union(self, *streams): + """See io.ray.streaming.api.stream.DataStream.union""" + assert len(streams) >= 1, "Need at least one stream to union with" + j_streams = [s._j_stream for s in streams] + j_stream = self._gateway_client().union(self._j_stream, *j_streams) + return JavaUnionStream(self, j_stream) + + def key_by(self, java_func_class): + """See io.ray.streaming.api.stream.DataStream.keyBy""" + self._check_partition_call() + return JavaKeyDataStream(self, + self._unary_call("keyBy", java_func_class)) + + def broadcast(self, java_func_class): + """See io.ray.streaming.api.stream.DataStream.broadcast""" + self._check_partition_call() + return JavaDataStream(self, + self._unary_call("broadcast", java_func_class)) + + def partition_by(self, java_func_class): + """See io.ray.streaming.api.stream.DataStream.partitionBy""" + self._check_partition_call() + return JavaDataStream(self, + self._unary_call("partitionBy", java_func_class)) + + def sink(self, java_func_class): + """See io.ray.streaming.api.stream.DataStream.sink""" + return JavaStreamSink(self, self._unary_call("sink", java_func_class)) + + def as_python_stream(self): + """ + Convert this stream as a python 
DataStream. + The converted stream and this stream are the same logical stream, + which has same stream id. Changes in converted stream will be reflected + in this stream and vice versa. + """ + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "asPythonStream") + return DataStream(self, j_stream) + + def _check_partition_call(self): + """ + If parent stream is a python stream, we can't call partition related + methods in the java stream + """ + if self.input_stream is not None and \ + self.input_stream.get_language() == function.Language.PYTHON: + raise Exception("Partition related methods can't be called on a" + "java stream if parent stream is a python stream.") + + def _unary_call(self, func_name, java_func_class): + j_func = self._gateway_client().new_instance(java_func_class) + j_stream = self._gateway_client(). \ + call_method(self._j_stream, func_name, j_func) + return j_stream + + +class KeyDataStream(DataStream): + """Represents a DataStream returned by a key-by operation. + Wrapper of java io.ray.streaming.python.stream.PythonKeyDataStream + """ + + def __init__(self, input_stream, j_stream): + super().__init__(input_stream, j_stream) + + def reduce(self, func): + """ + Applies a reduce transformation on the grouped data stream grouped on + by the given key function. + The :class:`ray.streaming.function.ReduceFunction` will receive input + values based on the key value. Only input values with the same key will + go to the same reducer. + + Args: + func: The ReduceFunction that will be called for every element of + the input values with the same key. If `func` is a python function + instead of a subclass of ReduceFunction, it will be wrapped as + SimpleReduceFunction. + + Returns: + A transformed DataStream. + """ + if not isinstance(func, function.ReduceFunction): + func = function.SimpleReduceFunction(func) + j_func = self._gateway_client().create_py_func( + function.serialize(func)) + j_stream = self._gateway_client(). 
\ + call_method(self._j_stream, "reduce", j_func) + return DataStream(self, j_stream) + + def as_java_stream(self): + """ + Convert this stream as a java KeyDataStream. + The converted stream and this stream are the same logical stream, + which has same stream id. Changes in converted stream will be reflected + in this stream and vice versa. + """ + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "asJavaStream") + return JavaKeyDataStream(self, j_stream) + + +class JavaKeyDataStream(JavaDataStream): + """ + Represents a DataStream returned by a key-by operation in java. + Wrapper of io.ray.streaming.api.stream.KeyDataStream + """ + + def __init__(self, input_stream, j_stream): + super().__init__(input_stream, j_stream) + + def reduce(self, java_func_class): + """See io.ray.streaming.api.stream.KeyDataStream.reduce""" + return JavaDataStream(self, + super()._unary_call("reduce", java_func_class)) + + def as_python_stream(self): + """ + Convert this stream as a python KeyDataStream. + The converted stream and this stream are the same logical stream, + which has same stream id. Changes in converted stream will be reflected + in this stream and vice versa. + """ + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "asPythonStream") + return KeyDataStream(self, j_stream) + + +class UnionStream(DataStream): + """Represents a union stream. + Wrapper of java io.ray.streaming.python.stream.PythonUnionStream + """ + + def __init__(self, input_stream, j_stream): + super().__init__(input_stream, j_stream) + + def get_language(self): + return function.Language.PYTHON + + +class JavaUnionStream(JavaDataStream): + """Represents a java union stream. + Wrapper of java io.ray.streaming.api.stream.UnionStream + """ + + def __init__(self, input_stream, j_stream): + super().__init__(input_stream, j_stream) + + def get_language(self): + return function.Language.JAVA + + +class StreamSource(DataStream): + """Represents a source of the DataStream. 
+ Wrapper of java io.ray.streaming.python.stream.PythonStreamSource + """ + + def __init__(self, j_stream, streaming_context, source_func): + super().__init__(None, j_stream, streaming_context=streaming_context) + self.source_func = source_func + + def get_language(self): + return function.Language.PYTHON + + @staticmethod + def build_source(streaming_context, func): + """Build a StreamSource source from a source function. + Args: + streaming_context: Stream context + func: A instance of `SourceFunction` + Returns: + A StreamSource + """ + j_stream = streaming_context._gateway_client. \ + create_py_stream_source(function.serialize(func)) + return StreamSource(j_stream, streaming_context, func) + + +class JavaStreamSource(JavaDataStream): + """Represents a source of the java DataStream. + Wrapper of java io.ray.streaming.api.stream.DataStreamSource + """ + + def __init__(self, j_stream, streaming_context): + super().__init__(None, j_stream, streaming_context=streaming_context) + + def get_language(self): + return function.Language.JAVA + + @staticmethod + def build_source(streaming_context, java_source_func_class): + """Build a java StreamSource source from a java source function. + Args: + streaming_context: Stream context + java_source_func_class: qualified class name of java SourceFunction + Returns: + A java StreamSource + """ + j_func = streaming_context._gateway_client() \ + .new_instance(java_source_func_class) + j_stream = streaming_context._gateway_client() \ + .call_function("io.ray.streaming.api.stream.DataStreamSource" + "fromSource", streaming_context._j_ctx, j_func) + return JavaStreamSource(j_stream, streaming_context) + + +class StreamSink(Stream): + """Represents a sink of the DataStream. 
+ Wrapper of java io.ray.streaming.python.stream.PythonStreamSink + """ + + def __init__(self, input_stream, j_stream, func): + super().__init__(input_stream, j_stream) + + def get_language(self): + return function.Language.PYTHON + + +class JavaStreamSink(Stream): + """Represents a sink of the java DataStream. + Wrapper of java io.ray.streaming.api.stream.StreamSink + """ + + def __init__(self, input_stream, j_stream): + super().__init__(input_stream, j_stream) + + def get_language(self): + return function.Language.JAVA diff --git a/streaming/python/examples/articles.txt b/streaming/python/examples/articles.txt new file mode 100644 index 00000000..0bb455fd --- /dev/null +++ b/streaming/python/examples/articles.txt @@ -0,0 +1,8 @@ +New York City +Berlin +London +Paris +United States +Germany +France +United Kingdom diff --git a/streaming/python/examples/wordcount.py b/streaming/python/examples/wordcount.py new file mode 100644 index 00000000..d10782b5 --- /dev/null +++ b/streaming/python/examples/wordcount.py @@ -0,0 +1,88 @@ +import argparse +import logging +import sys +import time + +import ray +import wikipedia +from ray.streaming import StreamingContext +from ray.streaming.config import Config + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +parser = argparse.ArgumentParser() +parser.add_argument( + "--titles-file", + required=True, + help="the file containing the wikipedia titles to lookup") + + +# A custom data source that reads articles from wikipedia +# Custom data sources need to implement a get_next() method +# that returns the next data element, in this case sentences +class Wikipedia: + def __init__(self, title_file): + # Titles in this file will be as queries + self.title_file = title_file + self.title_reader = iter(list(open(self.title_file, "r").readlines())) + self.done = False + self.article_done = True + self.sentences = iter([]) + + # Returns next sentence from a wikipedia article + def get_next(self): + if 
self.done: + return None # Source exhausted + while True: + if self.article_done: + try: # Try next title + next_title = next(self.title_reader) + except StopIteration: + self.done = True # Source exhausted + return None + # Get next article + logger.debug("Next article: {}".format(next_title)) + article = wikipedia.page(next_title).content + # Split article in sentences + self.sentences = iter(article.split(".")) + self.article_done = False + try: # Try next sentence + sentence = next(self.sentences) + logger.debug("Next sentence: {}".format(sentence)) + return sentence + except StopIteration: + self.article_done = True + + +# Splits input line into words and +# outputs records of the form (word,1) +def splitter(line): + return [(word, 1) for word in line.split()] + + +if __name__ == "__main__": + # Get program parameters + args = parser.parse_args() + titles_file = str(args.titles_file) + + ray.init(job_config=ray.job_config.JobConfig(code_search_path=sys.path)) + + ctx = StreamingContext.Builder() \ + .option(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL) \ + .build() + # A Ray streaming environment with the default configuration + ctx.set_parallelism(1) # Each operator will be executed by one actor + + # Reads articles from wikipedia, splits them in words, + # shuffles words, and counts the occurrences of each word. 
+ stream = ctx.source(Wikipedia(titles_file)) \ + .flat_map(splitter) \ + .key_by(lambda x: x[0]) \ + .reduce(lambda old_value, new_value: + (old_value[0], old_value[1] + new_value[1])) \ + .sink(print) + start = time.time() + ctx.execute("wordcount") + end = time.time() + logger.info("Elapsed time: {} secs".format(end - start)) diff --git a/streaming/python/function.py b/streaming/python/function.py new file mode 100644 index 00000000..e8a26feb --- /dev/null +++ b/streaming/python/function.py @@ -0,0 +1,350 @@ +import enum +import importlib +import inspect +import sys +from abc import ABC, abstractmethod + +from ray import cloudpickle +from ray.streaming.runtime import gateway_client + + +class Language(enum.Enum): + JAVA = 0 + PYTHON = 1 + + +class Function(ABC): + """The base interface for all user-defined functions.""" + + def open(self, runtime_context): + pass + + def close(self): + pass + + def save_checkpoint(self): + pass + + def load_checkpoint(self, checkpoint_obj): + pass + + +class EmptyFunction(Function): + """Default function which does nothing""" + + def open(self, runtime_context): + pass + + def close(self): + pass + + +class SourceContext(ABC): + """ + Interface that source functions use to emit elements, and possibly + watermarks.""" + + @abstractmethod + def collect(self, element): + """Emits one element from the source, without attaching a timestamp.""" + pass + + +class SourceFunction(Function): + """Interface of Source functions.""" + + @abstractmethod + def init(self, parallel, index): + """ + Args: + parallel: parallelism of source function + index: task index of this function and goes up from 0 to + parallel-1. + """ + pass + + @abstractmethod + def fetch(self, ctx: SourceContext): + """Starts the source. Implementations can use the + :class:`SourceContext` to emit elements. + """ + pass + + def close(self): + pass + + +class MapFunction(Function): + """ + Base interface for Map functions. 
Map functions take elements and transform + them element wise. A Map function always produces a single result element + for each input element. + """ + + @abstractmethod + def map(self, value): + pass + + +class FlatMapFunction(Function): + """ + Base interface for flatMap functions. FlatMap functions take elements and + transform them into zero, one, or more elements. + """ + + @abstractmethod + def flat_map(self, value, collector): + """Takes an element from the input data set and transforms it into zero, + one, or more elements. + + Args: + value: The input value. + collector: The collector for returning result values. + """ + pass + + +class FilterFunction(Function): + """ + A filter function is a predicate applied individually to each record. + The predicate decides whether to keep the element, or to discard it. + """ + + @abstractmethod + def filter(self, value): + """The filter function that evaluates the predicate. + + Args: + value: The value to be filtered. + + Returns: + True for values that should be retained, false for values to be + filtered out. + """ + pass + + +class KeyFunction(Function): + """ + A key function is extractor which takes an object and returns the + deterministic key for that object. + """ + + @abstractmethod + def key_by(self, value): + """User-defined function that deterministically extracts the key from + an object. + + Args: + value: The object to get the key from. + + Returns: + The extracted key. + """ + pass + + +class ReduceFunction(Function): + """ + Base interface for Reduce functions. Reduce functions combine groups of + elements to a single value, by taking always two elements and combining + them into one. + """ + + @abstractmethod + def reduce(self, old_value, new_value): + """ + The core method of ReduceFunction, combining two values into one value + of the same type. The reduce function is consecutively applied to all + values of a group until only a single value remains. 
+ + Args: + old_value: The old value to combine. + new_value: The new input value to combine. + + Returns: + The combined value of both values. + """ + pass + + +class SinkFunction(Function): + """Interface for implementing user defined sink functionality.""" + + @abstractmethod + def sink(self, value): + """Writes the given value to the sink. This function is called for + every record.""" + pass + + +class CollectionSourceFunction(SourceFunction): + def __init__(self, values): + self.values = values + + def init(self, parallel, index): + pass + + def fetch(self, ctx: SourceContext): + for v in self.values: + ctx.collect(v) + self.values = [] + + +class LocalFileSourceFunction(SourceFunction): + def __init__(self, filename): + self.filename = filename + self.done = False + + def init(self, parallel, index): + pass + + def fetch(self, ctx: SourceContext): + if self.done: + return + with open(self.filename, "r") as f: + line = f.readline() + while line != "": + ctx.collect(line[:-1]) + line = f.readline() + self.done = True + + +class SimpleMapFunction(MapFunction): + def __init__(self, func): + self.func = func + + def map(self, value): + return self.func(value) + + +class SimpleFlatMapFunction(FlatMapFunction): + """ + Wrap a python function as :class:`FlatMapFunction` + + >>> assert SimpleFlatMapFunction(lambda x: x.split()) + >>> def flat_func(x, collector): + ... for item in x.split(): + ... collector.collect(item) + >>> assert SimpleFlatMapFunction(flat_func) + """ + + def __init__(self, func): + """ + Args: + func: a python function which takes an element from input argument + and transforms it into zero, one, or more elements. + Or takes an element from input argument, and uses the provided collector + to collect zero, one, or more elements. 
+ """ + self.func = func + self.process_func = None + sig = inspect.signature(func) + assert len(sig.parameters) <= 2, \ + "func should receive value [, collector] as arguments" + if len(sig.parameters) == 2: + + def process(value, collector): + func(value, collector) + + self.process_func = process + else: + + def process(value, collector): + for elem in func(value): + collector.collect(elem) + + self.process_func = process + + def flat_map(self, value, collector): + self.process_func(value, collector) + + +class SimpleFilterFunction(FilterFunction): + def __init__(self, func): + self.func = func + + def filter(self, value): + return self.func(value) + + +class SimpleKeyFunction(KeyFunction): + def __init__(self, func): + self.func = func + + def key_by(self, value): + return self.func(value) + + +class SimpleReduceFunction(ReduceFunction): + def __init__(self, func): + self.func = func + + def reduce(self, old_value, new_value): + return self.func(old_value, new_value) + + +class SimpleSinkFunction(SinkFunction): + def __init__(self, func): + self.func = func + + def sink(self, value): + return self.func(value) + + +def serialize(func: Function): + """Serialize a streaming :class:`Function`""" + return cloudpickle.dumps(func) + + +def deserialize(func_bytes): + """Deserialize a binary function serialized by `serialize` method.""" + return cloudpickle.loads(func_bytes) + + +def load_function(descriptor_func_bytes: bytes): + """ + Deserialize `descriptor_func_bytes` to get function info, then + get or load streaming function. 
+ Note that this function must be kept in sync with + `io.ray.streaming.runtime.python.GraphPbBuilder.serializeFunction` + + Args: + descriptor_func_bytes: serialized function info + + Returns: + a streaming function + """ + assert len(descriptor_func_bytes) > 0 + function_bytes, module_name, function_name, function_interface \ + = gateway_client.deserialize(descriptor_func_bytes) + if function_bytes: + return deserialize(function_bytes) + else: + assert module_name + assert function_interface + function_interface = getattr(sys.modules[__name__], function_interface) + mod = importlib.import_module(module_name) + assert function_name + func = getattr(mod, function_name) + # If func is a python function, user function is a simple python + # function, which will be wrapped as a SimpleXXXFunction. + # If func is a python class, user function is a sub class + # of XXXFunction. + if inspect.isfunction(func): + simple_func_class = _get_simple_function_class(function_interface) + return simple_func_class(func) + else: + assert issubclass(func, function_interface) + return func() + + +def _get_simple_function_class(function_interface): + """Get the wrapper function for the given `function_interface`.""" + for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isclass(obj) and issubclass(obj, function_interface): + if obj is not function_interface and obj.__name__.startswith( + "Simple"): + return obj + raise Exception( + "SimpleFunction for {} doesn't exist".format(function_interface)) diff --git a/streaming/python/includes/__init__.pxd b/streaming/python/includes/__init__.pxd new file mode 100644 index 00000000..e69de29b diff --git a/streaming/python/includes/libstreaming.pxd b/streaming/python/includes/libstreaming.pxd new file mode 100644 index 00000000..899e5169 --- /dev/null +++ b/streaming/python/includes/libstreaming.pxd @@ -0,0 +1,201 @@ +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level 
= 3 +# flake8: noqa + +from libc.stdint cimport * +from libcpp cimport bool as c_bool +from libcpp.memory cimport shared_ptr +from libcpp.vector cimport vector as c_vector +from libcpp.list cimport list as c_list +from cpython cimport PyObject +cimport cpython +from libcpp.unordered_map cimport unordered_map as c_unordered_map +from cython.operator cimport dereference, postincrement + + +cdef inline object PyObject_to_object(PyObject* o): + # Cast to "object" increments reference count + cdef object result = o + cpython.Py_DECREF(result) + return result + +from ray.includes.common cimport ( + CLanguage, + CRayObject, + CRayStatus, + CRayFunction +) + +from ray.includes.unique_ids cimport ( + CActorID, + CJobID, + CTaskID, + CObjectID, +) + +cdef extern from "common/status.h" namespace "ray::streaming" nogil: + cdef cppclass CStreamingStatus "ray::streaming::StreamingStatus": + pass + cdef CStreamingStatus StatusOK "ray::streaming::StreamingStatus::OK" + cdef CStreamingStatus StatusReconstructTimeOut "ray::streaming::StreamingStatus::ReconstructTimeOut" + cdef CStreamingStatus StatusQueueIdNotFound "ray::streaming::StreamingStatus::QueueIdNotFound" + cdef CStreamingStatus StatusResubscribeFailed "ray::streaming::StreamingStatus::ResubscribeFailed" + cdef CStreamingStatus StatusEmptyRingBuffer "ray::streaming::StreamingStatus::EmptyRingBuffer" + cdef CStreamingStatus StatusFullChannel "ray::streaming::StreamingStatus::FullChannel" + cdef CStreamingStatus StatusNoSuchItem "ray::streaming::StreamingStatus::NoSuchItem" + cdef CStreamingStatus StatusInitQueueFailed "ray::streaming::StreamingStatus::InitQueueFailed" + cdef CStreamingStatus StatusGetBundleTimeOut "ray::streaming::StreamingStatus::GetBundleTimeOut" + cdef CStreamingStatus StatusSkipSendEmptyMessage "ray::streaming::StreamingStatus::SkipSendEmptyMessage" + cdef CStreamingStatus StatusInterrupted "ray::streaming::StreamingStatus::Interrupted" + cdef CStreamingStatus StatusWaitQueueTimeOut 
"ray::streaming::StreamingStatus::WaitQueueTimeOut" + cdef CStreamingStatus StatusOutOfMemory "ray::streaming::StreamingStatus::OutOfMemory" + cdef CStreamingStatus StatusInvalid "ray::streaming::StreamingStatus::Invalid" + cdef CStreamingStatus StatusUnknownError "ray::streaming::StreamingStatus::UnknownError" + cdef CStreamingStatus StatusTailStatus "ray::streaming::StreamingStatus::TailStatus" + + cdef cppclass CStreamingCommon "ray::streaming::StreamingCommon": + void SetConfig(const uint8_t *, uint32_t size) + + +cdef extern from "runtime_context.h" namespace "ray::streaming" nogil: + cdef cppclass CRuntimeContext "ray::streaming::RuntimeContext": + CRuntimeContext() + void SetConfig(const uint8_t *data, uint32_t size) + inline void MarkMockTest() + inline c_bool IsMockTest() + +cdef extern from "message/message.h" namespace "ray::streaming" nogil: + cdef cppclass CStreamingMessageType "ray::streaming::StreamingMessageType": + pass + cdef CStreamingMessageType MessageTypeBarrier "ray::streaming::StreamingMessageType::Barrier" + cdef CStreamingMessageType MessageTypeMessage "ray::streaming::StreamingMessageType::Message" + cdef cppclass CStreamingMessage "ray::streaming::StreamingMessage": + inline uint8_t *RawData() const + inline uint8_t *Payload() const + inline uint32_t PayloadSize() const + inline uint32_t GetDataSize() const + inline CStreamingMessageType GetMessageType() const + inline uint64_t GetMessageId() const + @staticmethod + inline void GetBarrierIdFromRawData(const uint8_t *data, + CStreamingBarrierHeader *barrier_header) + cdef struct CStreamingBarrierHeader "ray::streaming::StreamingBarrierHeader": + CStreamingBarrierType barrier_type; + uint64_t barrier_id; + cdef cppclass CStreamingBarrierType "ray::streaming::StreamingBarrierType": + pass + cdef uint32_t kMessageHeaderSize; + cdef uint32_t kBarrierHeaderSize; + +cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil: + cdef cppclass CStreamingMessageBundleType 
"ray::streaming::StreamingMessageBundleType": + pass + cdef CStreamingMessageBundleType BundleTypeEmpty "ray::streaming::StreamingMessageBundleType::Empty" + cdef CStreamingMessageBundleType BundleTypeBarrier "ray::streaming::StreamingMessageBundleType::Barrier" + cdef CStreamingMessageBundleType BundleTypeBundle "ray::streaming::StreamingMessageBundleType::Bundle" + + cdef cppclass CStreamingMessageBundleMeta "ray::streaming::StreamingMessageBundleMeta": + CStreamingMessageBundleMeta() + inline uint64_t GetMessageBundleTs() const + inline uint64_t GetLastMessageId() const + inline uint32_t GetMessageListSize() const + inline CStreamingMessageBundleType GetBundleType() const + inline c_bool IsBarrier() + inline c_bool IsBundle() + + ctypedef shared_ptr[CStreamingMessageBundleMeta] CStreamingMessageBundleMetaPtr + uint32_t kMessageBundleHeaderSize "ray::streaming::kMessageBundleHeaderSize" + cdef cppclass CStreamingMessageBundle "ray::streaming::StreamingMessageBundle"(CStreamingMessageBundleMeta): + @staticmethod + void GetMessageListFromRawData(const uint8_t *data, uint32_t size, uint32_t msg_nums, + c_list[shared_ptr[CStreamingMessage]] &msg_list); + +cdef extern from "channel/channel.h" namespace "ray::streaming" nogil: + cdef struct CChannelCreationParameter "ray::streaming::ChannelCreationParameter": + CChannelCreationParameter() + CActorID actor_id; + shared_ptr[CRayFunction] async_function; + shared_ptr[CRayFunction] sync_function; + + cdef struct CStreamingQueueInfo "ray::streaming::StreamingQueueInfo": + uint64_t first_seq_id; + uint64_t last_message_id; + uint64_t target_message_id; + uint64_t consumed_message_id; + + cdef struct CConsumerChannelInfo "ray::streaming::ConsumerChannelInfo": + CObjectID channel_id; + uint64_t current_message_id; + uint64_t barrier_id; + uint64_t partial_barrier_id; + CStreamingQueueInfo queue_info; + uint64_t last_queue_item_delay; + uint64_t last_queue_item_latency; + uint64_t last_queue_target_diff; + uint64_t 
get_queue_item_times; + uint64_t notify_cnt; + CChannelCreationParameter parameter; + + cdef enum CTransferCreationStatus "ray::streaming::TransferCreationStatus": + FreshStarted = 0 + PullOk = 1 + Timeout = 2 + DataLost = 3 + Invalid = 999 + + +cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil: + cdef cppclass CReaderClient "ray::streaming::ReaderClient": + CReaderClient() + void OnReaderMessage(shared_ptr[CLocalMemoryBuffer] buffer); + shared_ptr[CLocalMemoryBuffer] OnReaderMessageSync(shared_ptr[CLocalMemoryBuffer] buffer); + + cdef cppclass CWriterClient "ray::streaming::WriterClient": + CWriterClient() + void OnWriterMessage(shared_ptr[CLocalMemoryBuffer] buffer); + shared_ptr[CLocalMemoryBuffer] OnWriterMessageSync(shared_ptr[CLocalMemoryBuffer] buffer); + + +cdef extern from "data_reader.h" namespace "ray::streaming" nogil: + cdef cppclass CDataBundle "ray::streaming::DataBundle": + uint8_t *data + uint32_t data_size + CObjectID c_from "from" + uint64_t seq_id + CStreamingMessageBundleMetaPtr meta + + cdef cppclass CDataReader "ray::streaming::DataReader"(CStreamingCommon): + CDataReader(shared_ptr[CRuntimeContext] &runtime_context) + void Init(const c_vector[CObjectID] &input_ids, + const c_vector[CChannelCreationParameter] ¶ms, + const c_vector[uint64_t] &msg_ids, + c_vector[CTransferCreationStatus] &creation_status, + int64_t timer_interval); + CStreamingStatus GetBundle(const uint32_t timeout_ms, + shared_ptr[CDataBundle] &message) + void GetOffsetInfo(c_unordered_map[CObjectID, CConsumerChannelInfo] *&offset_map); + void Stop() + + +cdef extern from "data_writer.h" namespace "ray::streaming" nogil: + cdef cppclass CDataWriter "ray::streaming::DataWriter"(CStreamingCommon): + CDataWriter(shared_ptr[CRuntimeContext] &runtime_context) + CStreamingStatus Init(const c_vector[CObjectID] &channel_ids, + const c_vector[CChannelCreationParameter] ¶ms, + const c_vector[uint64_t] &message_ids, + const c_vector[uint64_t] &queue_size_vec); + 
long WriteMessageToBufferRing( + const CObjectID &q_id, uint8_t *data, uint32_t data_size) + void BroadcastBarrier(uint64_t checkpoint_id, const uint8_t *data, uint32_t data_size) + void GetChannelOffset(c_vector[uint64_t] &result) + void ClearCheckpoint(uint64_t checkpoint_id) + void Run() + void Stop() + + +cdef extern from "ray/common/buffer.h" nogil: + cdef cppclass CLocalMemoryBuffer "ray::LocalMemoryBuffer": + uint8_t *Data() const + size_t Size() const diff --git a/streaming/python/includes/transfer.pxi b/streaming/python/includes/transfer.pxi new file mode 100644 index 00000000..d7beef5e --- /dev/null +++ b/streaming/python/includes/transfer.pxi @@ -0,0 +1,397 @@ +# flake8: noqa + +from libc.stdint cimport * +from libcpp cimport bool as c_bool +from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast +from libcpp.string cimport string as c_string +from libcpp.vector cimport vector as c_vector +from libcpp.list cimport list as c_list +from libcpp.unordered_map cimport unordered_map as c_unordered_map +from cython.operator cimport dereference, postincrement + +from ray.includes.common cimport ( + CRayFunction, + LANGUAGE_PYTHON, + LANGUAGE_JAVA, + CBuffer +) + +from ray.includes.unique_ids cimport ( + CActorID, + CObjectID +) +from ray._raylet cimport ( + Buffer, + ActorID, + ObjectRef, + FunctionDescriptor, +) + +cimport ray.streaming.includes.libstreaming as libstreaming +from ray.streaming.includes.libstreaming cimport ( + CStreamingStatus, + CStreamingMessage, + CStreamingMessageBundle, + CRuntimeContext, + CDataBundle, + CDataWriter, + CDataReader, + CReaderClient, + CWriterClient, + CLocalMemoryBuffer, + CChannelCreationParameter, + CTransferCreationStatus, + CConsumerChannelInfo, + CStreamingBarrierHeader, + kBarrierHeaderSize, +) +from ray._raylet import JavaFunctionDescriptor + +import logging + + +channel_logger = logging.getLogger(__name__) + +cdef class ChannelCreationParameter: + cdef: + CChannelCreationParameter parameter + + def 
__cinit__(self, ActorID actor_id, FunctionDescriptor async_func, FunctionDescriptor sync_func): + cdef: + shared_ptr[CRayFunction] async_func_ptr + shared_ptr[CRayFunction] sync_func_ptr + self.parameter = CChannelCreationParameter() + self.parameter.actor_id = (actor_id).data + if isinstance(async_func, JavaFunctionDescriptor): + self.parameter.async_function = make_shared[CRayFunction](LANGUAGE_JAVA, async_func.descriptor) + else: + self.parameter.async_function = make_shared[CRayFunction](LANGUAGE_PYTHON, async_func.descriptor) + if isinstance(sync_func, JavaFunctionDescriptor): + self.parameter.sync_function = make_shared[CRayFunction](LANGUAGE_JAVA, sync_func.descriptor) + else: + self.parameter.sync_function = make_shared[CRayFunction](LANGUAGE_PYTHON, sync_func.descriptor) + + cdef CChannelCreationParameter get_parameter(self): + return self.parameter + +cdef class ReaderClient: + cdef: + CReaderClient *client + + def __cinit__(self): + self.client = new CReaderClient() + + def __dealloc__(self): + del self.client + self.client = NULL + + def on_reader_message(self, const unsigned char[:] value): + cdef: + size_t size = value.nbytes + shared_ptr[CLocalMemoryBuffer] local_buf = \ + make_shared[CLocalMemoryBuffer]((&value[0]), size, True) + with nogil: + self.client.OnReaderMessage(local_buf) + + def on_reader_message_sync(self, const unsigned char[:] value): + cdef: + size_t size = value.nbytes + shared_ptr[CLocalMemoryBuffer] local_buf = \ + make_shared[CLocalMemoryBuffer]((&value[0]), size, True) + shared_ptr[CLocalMemoryBuffer] result_buffer + with nogil: + result_buffer = self.client.OnReaderMessageSync(local_buf) + return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer)) + + +cdef class WriterClient: + cdef: + CWriterClient * client + + def __cinit__(self): + self.client = new CWriterClient() + + def __dealloc__(self): + del self.client + self.client = NULL + + def on_writer_message(self, const unsigned char[:] value): + cdef: 
+ size_t size = value.nbytes + shared_ptr[CLocalMemoryBuffer] local_buf = \ + make_shared[CLocalMemoryBuffer]((&value[0]), size, True) + with nogil: + self.client.OnWriterMessage(local_buf) + + def on_writer_message_sync(self, const unsigned char[:] value): + cdef: + size_t size = value.nbytes + shared_ptr[CLocalMemoryBuffer] local_buf = \ + make_shared[CLocalMemoryBuffer]((&value[0]), size, True) + shared_ptr[CLocalMemoryBuffer] result_buffer + with nogil: + result_buffer = self.client.OnWriterMessageSync(local_buf) + return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer)) + + +cdef class DataWriter: + cdef: + CDataWriter *writer + + def __init__(self): + raise Exception("use create() to create DataWriter") + + @staticmethod + def create(list py_output_channels, + list output_creation_parameters: list[ChannelCreationParameter], + uint64_t queue_size, + list py_msg_ids, + bytes config_bytes, + c_bool is_mock): + cdef: + c_vector[CObjectID] channel_ids = bytes_list_to_qid_vec(py_output_channels) + c_vector[CChannelCreationParameter] initial_parameters + c_vector[uint64_t] msg_ids + CDataWriter *c_writer + ChannelCreationParameter parameter + cdef const unsigned char[:] config_data + for param in output_creation_parameters: + parameter = param + initial_parameters.push_back(parameter.get_parameter()) + for py_msg_id in py_msg_ids: + msg_ids.push_back(py_msg_id) + + cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]() + if is_mock: + ctx.get().MarkMockTest() + if config_bytes: + config_data = config_bytes + channel_logger.info("DataWriter load config, config bytes size: %s", config_data.nbytes) + ctx.get().SetConfig((&config_data[0]), config_data.nbytes) + c_writer = new CDataWriter(ctx) + cdef: + c_vector[CObjectID] remain_id_vec + c_vector[uint64_t] queue_size_vec + for i in range(channel_ids.size()): + queue_size_vec.push_back(queue_size) + cdef CStreamingStatus status = c_writer.Init(channel_ids, initial_parameters, 
msg_ids, queue_size_vec) + if remain_id_vec.size() != 0: + channel_logger.warning("failed queue amounts => %s", remain_id_vec.size()) + if status != libstreaming.StatusOK: + msg = "initialize writer failed, status={}".format(status) + channel_logger.error(msg) + del c_writer + import ray.streaming.runtime.transfer as transfer + raise transfer.ChannelInitException(msg, qid_vector_to_list(remain_id_vec)) + + c_writer.Run() + channel_logger.info("create native writer succeed") + cdef DataWriter writer = DataWriter.__new__(DataWriter) + writer.writer = c_writer + return writer + + def __dealloc__(self): + if self.writer != NULL: + del self.writer + channel_logger.info("deleted DataWriter") + self.writer = NULL + + def write(self, ObjectRef qid, const unsigned char[:] value): + """support zero-copy bytes, byte array, array of unsigned char""" + cdef: + CObjectID native_id = qid.data + uint64_t msg_id + uint8_t *data = (&value[0]) + uint32_t size = value.nbytes + with nogil: + msg_id = self.writer.WriteMessageToBufferRing(native_id, data, size) + return msg_id + + def broadcast_barrier(self, uint64_t checkpoint_id, const unsigned char[:] value): + cdef: + uint8_t *data = (&value[0]) + uint32_t size = value.nbytes + with nogil: + self.writer.BroadcastBarrier(checkpoint_id, data, size) + + def get_output_checkpoints(self): + cdef: + c_vector[uint64_t] results + self.writer.GetChannelOffset(results) + return results + + def clear_checkpoint(self, checkpoint_id): + cdef: + uint64_t c_checkpoint_id = checkpoint_id + with nogil: + self.writer.ClearCheckpoint(c_checkpoint_id) + + def stop(self): + self.writer.Stop() + channel_logger.info("stopped DataWriter") + + +cdef class DataReader: + cdef: + CDataReader *reader + readonly bytes meta + readonly bytes data + + def __init__(self): + raise Exception("use create() to create DataReader") + + @staticmethod + def create(list py_input_queues, + list input_creation_parameters: list[ChannelCreationParameter], + list py_msg_ids, + 
int64_t timer_interval, + bytes config_bytes, + c_bool is_mock): + cdef: + c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues) + c_vector[CChannelCreationParameter] initial_parameters + c_vector[uint64_t] msg_ids + c_vector[CTransferCreationStatus] c_creation_status + CDataReader *c_reader + ChannelCreationParameter parameter + cdef const unsigned char[:] config_data + for param in input_creation_parameters: + parameter = param + initial_parameters.push_back(parameter.get_parameter()) + + for py_msg_id in py_msg_ids: + msg_ids.push_back(py_msg_id) + cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]() + if config_bytes: + config_data = config_bytes + channel_logger.info("DataReader load config, config bytes size: %s", config_data.nbytes) + ctx.get().SetConfig((&(config_data[0])), config_data.nbytes) + if is_mock: + ctx.get().MarkMockTest() + c_reader = new CDataReader(ctx) + c_reader.Init(queue_id_vec, initial_parameters, msg_ids, c_creation_status, timer_interval) + + creation_status_map = {} + if not c_creation_status.empty(): + for i in range(queue_id_vec.size()): + k = queue_id_vec[i].Binary() + v = c_creation_status[i] + creation_status_map[k] = v + + channel_logger.info("create native reader succeed") + cdef DataReader reader = DataReader.__new__(DataReader) + reader.reader = c_reader + return reader, creation_status_map + + def __dealloc__(self): + if self.reader != NULL: + del self.reader + channel_logger.info("deleted DataReader") + self.reader = NULL + + def read(self, uint32_t timeout_millis): + cdef: + shared_ptr[CDataBundle] bundle + CStreamingStatus status + with nogil: + status = self.reader.GetBundle(timeout_millis, bundle) + if status != libstreaming.StatusOK: + if status == libstreaming.StatusInterrupted: + # avoid cyclic import + import ray.streaming.runtime.transfer as transfer + raise transfer.ChannelInterruptException("reader interrupted") + elif status == libstreaming.StatusInitQueueFailed: + import 
ray.streaming.runtime.transfer as transfer + raise transfer.ChannelInitException("init channel failed") + elif status == libstreaming.StatusGetBundleTimeOut: + return [] + else: + raise Exception("no such status " + str(status)) + cdef: + uint32_t msg_nums + CObjectID queue_id = bundle.get().c_from + c_list[shared_ptr[CStreamingMessage]] msg_list + list msgs = [] + uint64_t timestamp + uint64_t msg_id + c_unordered_map[CObjectID, CConsumerChannelInfo] *offset_map = NULL + shared_ptr[CStreamingMessage] barrier + CStreamingBarrierHeader barrier_header + c_unordered_map[CObjectID, CConsumerChannelInfo].iterator it + + cdef uint32_t bundle_type = (bundle.get().meta.get().GetBundleType()) + # avoid cyclic import + from ray.streaming.runtime.transfer import DataMessage + if bundle_type == libstreaming.BundleTypeBundle: + msg_nums = bundle.get().meta.get().GetMessageListSize() + CStreamingMessageBundle.GetMessageListFromRawData( + bundle.get().data + libstreaming.kMessageBundleHeaderSize, + bundle.get().data_size - libstreaming.kMessageBundleHeaderSize, + msg_nums, + msg_list) + timestamp = bundle.get().meta.get().GetMessageBundleTs() + for msg in msg_list: + msg_bytes = msg.get().Payload()[:msg.get().PayloadSize()] + qid_bytes = queue_id.Binary() + msg_id = msg.get().GetMessageId() + msgs.append( + DataMessage(msg_bytes, timestamp, msg_id, qid_bytes)) + return msgs + elif bundle_type == libstreaming.BundleTypeEmpty: + timestamp = bundle.get().meta.get().GetMessageBundleTs() + msg_id = bundle.get().meta.get().GetLastMessageId() + return [DataMessage(None, timestamp, msg_id, queue_id.Binary(), True)] + elif bundle.get().meta.get().IsBarrier(): + py_offset_map = {} + self.reader.GetOffsetInfo(offset_map) + it = offset_map.begin() + while it != offset_map.end(): + queue_id_bytes = dereference(it).first.Binary() + current_message_id = dereference(it).second.current_message_id + py_offset_map[queue_id_bytes] = current_message_id + postincrement(it) + msg_nums = 
bundle.get().meta.get().GetMessageListSize() + CStreamingMessageBundle.GetMessageListFromRawData( + bundle.get().data + libstreaming.kMessageBundleHeaderSize, + bundle.get().data_size - libstreaming.kMessageBundleHeaderSize, + msg_nums, + msg_list) + timestamp = bundle.get().meta.get().GetMessageBundleTs() + barrier = msg_list.front() + msg_id = barrier.get().GetMessageId() + CStreamingMessage.GetBarrierIdFromRawData(barrier.get().Payload(), &barrier_header) + barrier_id = barrier_header.barrier_id + barrier_data = (barrier.get().Payload() + kBarrierHeaderSize)[ + :barrier.get().PayloadSize() - kBarrierHeaderSize] + barrier_type = barrier_header.barrier_type + py_queue_id = queue_id.Binary() + from ray.streaming.runtime.transfer import CheckpointBarrier + return [CheckpointBarrier( + barrier_data, timestamp, msg_id, py_queue_id, py_offset_map, + barrier_id, barrier_type)] + else: + raise Exception("Unsupported bundle type {}".format(bundle_type)) + + + def stop(self): + self.reader.Stop() + channel_logger.info("stopped DataReader") + + +cdef c_vector[CObjectID] bytes_list_to_qid_vec(list py_queue_ids) except *: + assert len(py_queue_ids) > 0 + cdef: + c_vector[CObjectID] queue_id_vec + c_string q_id_data + for q_id in py_queue_ids: + q_id_data = q_id + assert q_id_data.size() == CObjectID.Size(), f"{q_id_data.size()}, {CObjectID.Size()}" + obj_id = CObjectID.FromBinary(q_id_data) + queue_id_vec.push_back(obj_id) + return queue_id_vec + +cdef c_vector[c_string] qid_vector_to_list(c_vector[CObjectID] queue_id_vec): + queues = [] + for obj_id in queue_id_vec: + queues.append(obj_id.Binary()) + return queues diff --git a/streaming/python/message.py b/streaming/python/message.py new file mode 100644 index 00000000..637cdf00 --- /dev/null +++ b/streaming/python/message.py @@ -0,0 +1,34 @@ +class Record: + """Data record in data stream""" + + def __init__(self, value): + self.value = value + self.stream = None + + def __repr__(self): + return 
"Record({})".format(self.value) + + def __eq__(self, other): + if type(self) is type(other): + return (self.stream, self.value) == (other.stream, other.value) + return False + + def __hash__(self): + return hash((self.stream, self.value)) + + +class KeyRecord(Record): + """Data record in a keyed data stream""" + + def __init__(self, key, value): + super().__init__(value) + self.key = key + + def __eq__(self, other): + if type(self) is type(other): + return (self.stream, self.key, self.value) ==\ + (other.stream, other.key, other.value) + return False + + def __hash__(self): + return hash((self.stream, self.key, self.value)) diff --git a/streaming/python/operator.py b/streaming/python/operator.py new file mode 100644 index 00000000..9163519d --- /dev/null +++ b/streaming/python/operator.py @@ -0,0 +1,399 @@ +import enum +import importlib +import logging +from abc import ABC, abstractmethod + +from ray.streaming import function +from ray.streaming import message +from ray.streaming.collector import Collector +from ray.streaming.collector import CollectionCollector +from ray.streaming.function import SourceFunction +from ray.streaming.runtime import gateway_client + +logger = logging.getLogger(__name__) + + +class OperatorType(enum.Enum): + SOURCE = 0 # Sources are where your program reads its input from + ONE_INPUT = 1 # This operator has one data stream as it's input stream. + TWO_INPUT = 2 # This operator has two data stream as it's input stream. + + +class Operator(ABC): + """ + Abstract base class for all operators. + An operator is used to run a :class:`function.Function`. 
+ """ + + @abstractmethod + def open(self, collectors, runtime_context): + pass + + @abstractmethod + def finish(self): + pass + + @abstractmethod + def close(self): + pass + + @abstractmethod + def operator_type(self) -> OperatorType: + pass + + @abstractmethod + def save_checkpoint(self): + pass + + @abstractmethod + def load_checkpoint(self, checkpoint_obj): + pass + + +class OneInputOperator(Operator, ABC): + """Interface for stream operators with one input.""" + + @abstractmethod + def process_element(self, record): + pass + + def operator_type(self): + return OperatorType.ONE_INPUT + + +class TwoInputOperator(Operator, ABC): + """Interface for stream operators with two input""" + + @abstractmethod + def process_element(self, record1, record2): + pass + + def operator_type(self): + return OperatorType.TWO_INPUT + + +class StreamOperator(Operator, ABC): + """ + Basic interface for stream operators. Implementers would implement one of + :class:`OneInputOperator` or :class:`TwoInputOperator` to to create + operators that process elements. 
+ """ + + def __init__(self, func): + self.func = func + self.collectors = None + self.runtime_context = None + + def open(self, collectors, runtime_context): + self.collectors = collectors + self.runtime_context = runtime_context + self.func.open(runtime_context) + + def finish(self): + pass + + def close(self): + self.func.close() + + def collect(self, record): + for collector in self.collectors: + collector.collect(record) + + def save_checkpoint(self): + self.func.save_checkpoint() + + def load_checkpoint(self, checkpoint_obj): + self.func.load_checkpoint(checkpoint_obj) + + +class SourceOperator(Operator, ABC): + @abstractmethod + def fetch(self): + pass + + +class SourceOperatorImpl(SourceOperator, StreamOperator): + """ + Operator to run a :class:`function.SourceFunction` + """ + + class SourceContextImpl(function.SourceContext): + def __init__(self, collectors): + self.collectors = collectors + + def collect(self, value): + for collector in self.collectors: + collector.collect(message.Record(value)) + + def __init__(self, func: SourceFunction): + assert isinstance(func, function.SourceFunction) + super().__init__(func) + self.source_context = None + + def open(self, collectors, runtime_context): + super().open(collectors, runtime_context) + self.source_context = SourceOperatorImpl.SourceContextImpl(collectors) + self.func.init(runtime_context.get_parallelism(), + runtime_context.get_task_index()) + + def fetch(self): + self.func.fetch(self.source_context) + + def operator_type(self): + return OperatorType.SOURCE + + +class MapOperator(StreamOperator, OneInputOperator): + """ + Operator to run a :class:`function.MapFunction` + """ + + def __init__(self, map_func: function.MapFunction): + assert isinstance(map_func, function.MapFunction) + super().__init__(map_func) + + def process_element(self, record): + self.collect(message.Record(self.func.map(record.value))) + + +class FlatMapOperator(StreamOperator, OneInputOperator): + """ + Operator to run a 
:class:`function.FlatMapFunction` + """ + + def __init__(self, flat_map_func: function.FlatMapFunction): + assert isinstance(flat_map_func, function.FlatMapFunction) + super().__init__(flat_map_func) + self.collection_collector = None + + def open(self, collectors, runtime_context): + super().open(collectors, runtime_context) + self.collection_collector = CollectionCollector(collectors) + + def process_element(self, record): + self.func.flat_map(record.value, self.collection_collector) + + +class FilterOperator(StreamOperator, OneInputOperator): + """ + Operator to run a :class:`function.FilterFunction` + """ + + def __init__(self, filter_func: function.FilterFunction): + assert isinstance(filter_func, function.FilterFunction) + super().__init__(filter_func) + + def process_element(self, record): + if self.func.filter(record.value): + self.collect(record) + + +class KeyByOperator(StreamOperator, OneInputOperator): + """ + Operator to run a :class:`function.KeyFunction` + """ + + def __init__(self, key_func: function.KeyFunction): + assert isinstance(key_func, function.KeyFunction) + super().__init__(key_func) + + def process_element(self, record): + key = self.func.key_by(record.value) + self.collect(message.KeyRecord(key, record.value)) + + +class ReduceOperator(StreamOperator, OneInputOperator): + """ + Operator to run a :class:`function.ReduceFunction` + """ + + def __init__(self, reduce_func: function.ReduceFunction): + assert isinstance(reduce_func, function.ReduceFunction) + super().__init__(reduce_func) + self.reduce_state = {} + + def open(self, collectors, runtime_context): + super().open(collectors, runtime_context) + + def process_element(self, record: message.KeyRecord): + key = record.key + value = record.value + if key in self.reduce_state: + old_value = self.reduce_state[key] + new_value = self.func.reduce(old_value, value) + self.reduce_state[key] = new_value + self.collect(message.Record(new_value)) + else: + self.reduce_state[key] = value + 
self.collect(record) + + +class SinkOperator(StreamOperator, OneInputOperator): + """ + Operator to run a :class:`function.SinkFunction` + """ + + def __init__(self, sink_func: function.SinkFunction): + assert isinstance(sink_func, function.SinkFunction) + super().__init__(sink_func) + + def process_element(self, record): + self.func.sink(record.value) + + +class UnionOperator(StreamOperator, OneInputOperator): + """Operator for union operation""" + + def __init__(self): + super().__init__(function.EmptyFunction()) + + def process_element(self, record): + self.collect(record) + + +class ChainedOperator(StreamOperator, ABC): + class ForwardCollector(Collector): + def __init__(self, succeeding_operator): + self.succeeding_operator = succeeding_operator + + def collect(self, record): + self.succeeding_operator.process_element(record) + + def __init__(self, operators, configs): + super().__init__(operators[0].func) + self.operators = operators + self.configs = configs + + def open(self, collectors, runtime_context): + # Dont' call super.open() as we `open` every operator separately. 
+ num_operators = len(self.operators) + succeeding_collectors = [ + ChainedOperator.ForwardCollector(operator) + for operator in self.operators[1:] + ] + for i in range(0, num_operators - 1): + forward_collectors = [succeeding_collectors[i]] + self.operators[i].open( + forward_collectors, + self.__create_runtime_context(runtime_context, i)) + self.operators[-1].open( + collectors, + self.__create_runtime_context(runtime_context, num_operators - 1)) + + def operator_type(self) -> OperatorType: + return self.operators[0].operator_type() + + def __create_runtime_context(self, runtime_context, index): + def get_config(): + return self.configs[index] + + runtime_context.get_config = get_config + return runtime_context + + @staticmethod + def new_chained_operator(operators, configs): + operator_type = operators[0].operator_type() + logger.info( + "Building ChainedOperator from operators {} and configs {}." + .format(operators, configs)) + if operator_type == OperatorType.SOURCE: + return ChainedSourceOperator(operators, configs) + elif operator_type == OperatorType.ONE_INPUT: + return ChainedOneInputOperator(operators, configs) + elif operator_type == OperatorType.TWO_INPUT: + return ChainedTwoInputOperator(operators, configs) + else: + raise Exception("Current operator type is not supported") + + +class ChainedSourceOperator(SourceOperator, ChainedOperator): + def __init__(self, operators, configs): + super().__init__(operators, configs) + + def fetch(self): + self.operators[0].fetch() + + +class ChainedOneInputOperator(ChainedOperator): + def __init__(self, operators, configs): + super().__init__(operators, configs) + + def process_element(self, record): + self.operators[0].process_element(record) + + +class ChainedTwoInputOperator(ChainedOperator): + def __init__(self, operators, configs): + super().__init__(operators, configs) + + def process_element(self, record1, record2): + self.operators[0].process_element(record1, record2) + + +def 
load_chained_operator(chained_operator_bytes: bytes): + """Load chained operator from serialized operators and configs""" + serialized_operators, configs = gateway_client.deserialize( + chained_operator_bytes) + operators = [ + load_operator(desc_bytes) for desc_bytes in serialized_operators + ] + return ChainedOperator.new_chained_operator(operators, configs) + + +def load_operator(descriptor_operator_bytes: bytes): + """ + Deserialize `descriptor_operator_bytes` to get operator info, then + create streaming operator. + Note that this function must be kept in sync with + `io.ray.streaming.runtime.python.GraphPbBuilder.serializeOperator` + + Args: + descriptor_operator_bytes: serialized operator info + + Returns: + a streaming operator + """ + assert len(descriptor_operator_bytes) > 0 + function_desc_bytes, module_name, class_name \ + = gateway_client.deserialize(descriptor_operator_bytes) + if function_desc_bytes: + return create_operator_with_func( + function.load_function(function_desc_bytes)) + else: + assert module_name + assert class_name + mod = importlib.import_module(module_name) + cls = getattr(mod, class_name) + assert issubclass(cls, Operator) + print("cls", cls) + return cls() + + +_function_to_operator = { + function.SourceFunction: SourceOperatorImpl, + function.MapFunction: MapOperator, + function.FlatMapFunction: FlatMapOperator, + function.FilterFunction: FilterOperator, + function.KeyFunction: KeyByOperator, + function.ReduceFunction: ReduceOperator, + function.SinkFunction: SinkOperator, +} + + +def create_operator_with_func(func: function.Function): + """Create an operator according to a :class:`function.Function` + + Args: + func: a subclass of function.Function + + Returns: + an operator + """ + operator_class = None + super_classes = func.__class__.mro() + for super_class in super_classes: + operator_class = _function_to_operator.get(super_class, None) + if operator_class is not None: + break + assert operator_class is not None + return 
operator_class(func) diff --git a/streaming/python/partition.py b/streaming/python/partition.py new file mode 100644 index 00000000..fb30ba7c --- /dev/null +++ b/streaming/python/partition.py @@ -0,0 +1,129 @@ +import importlib +import inspect +from abc import ABC, abstractmethod + +from ray import cloudpickle +from ray.streaming.runtime import gateway_client + + +class Partition(ABC): + """Interface of the partitioning strategy.""" + + @abstractmethod + def partition(self, record, num_partition: int): + """Given a record and downstream partitions, determine which partition(s) + should receive the record. + + Args: + record: The record. + num_partition: num of partitions + Returns: + IDs of the downstream partitions that should receive the record. + """ + pass + + +class BroadcastPartition(Partition): + """Broadcast the record to all downstream partitions.""" + + def __init__(self): + self.__partitions = [] + + def partition(self, record, num_partition: int): + if len(self.__partitions) != num_partition: + self.__partitions = list(range(num_partition)) + return self.__partitions + + +class KeyPartition(Partition): + """Partition the record by the key.""" + + def __init__(self): + self.__partitions = [-1] + + def partition(self, key_record, num_partition: int): + # TODO support key group + self.__partitions[0] = abs(hash(key_record.key)) % num_partition + return self.__partitions + + +class RoundRobinPartition(Partition): + """Partition record to downstream tasks in a round-robin matter.""" + + def __init__(self): + self.__partitions = [-1] + self.seq = 0 + + def partition(self, key_record, num_partition: int): + self.seq = (self.seq + 1) % num_partition + self.__partitions[0] = self.seq + return self.__partitions + + +class ForwardPartition(Partition): + """Default partition for operator if the operator can be chained with + succeeding operators.""" + + def __init__(self): + self.__partitions = [0] + + def partition(self, key_record, num_partition: int): + return 
self.__partitions + + +class SimplePartition(Partition): + """Wrap a python function as subclass of :class:`Partition`""" + + def __init__(self, func): + self.func = func + + def partition(self, record, num_partition: int): + return self.func(record, num_partition) + + +def serialize(partition_func): + """ + Serialize the partition function so that it can be deserialized by + :func:`deserialize` + """ + return cloudpickle.dumps(partition_func) + + +def deserialize(partition_bytes): + """Deserialize the binary partition function serialized by + :func:`serialize`""" + return cloudpickle.loads(partition_bytes) + + +def load_partition(descriptor_partition_bytes: bytes): + """ + Deserialize `descriptor_partition_bytes` to get partition info, then + get or load partition function. + Note that this function must be kept in sync with + `io.ray.streaming.runtime.python.GraphPbBuilder.serializePartition` + + Args: + descriptor_partition_bytes: serialized partition info + + Returns: + partition function + """ + assert len(descriptor_partition_bytes) > 0 + partition_bytes, module_name, function_name =\ + gateway_client.deserialize(descriptor_partition_bytes) + if partition_bytes: + return deserialize(partition_bytes) + else: + assert module_name + mod = importlib.import_module(module_name) + assert function_name + func = getattr(mod, function_name) + # If func is a python function, user partition is a simple python + # function, which will be wrapped as a SimplePartition. + # If func is a python class, user partition is a sub class + # of Partition. 
+ if inspect.isfunction(func): + return SimplePartition(func) + else: + assert issubclass(func, Partition) + return func() diff --git a/streaming/python/runtime/__init__.py b/streaming/python/runtime/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/streaming/python/runtime/command.py b/streaming/python/runtime/command.py new file mode 100644 index 00000000..cc5f02e1 --- /dev/null +++ b/streaming/python/runtime/command.py @@ -0,0 +1,30 @@ +class BaseWorkerCmd: + """ + base worker cmd + """ + + def __init__(self, actor_id): + self.from_actor_id = actor_id + + +class WorkerCommitReport(BaseWorkerCmd): + """ + worker commit report + """ + + def __init__(self, actor_id, commit_checkpoint_id): + super().__init__(actor_id) + self.commit_checkpoint_id = commit_checkpoint_id + + +class WorkerRollbackRequest(BaseWorkerCmd): + """ + worker rollback request + """ + + def __init__(self, actor_id, exception_msg): + super().__init__(actor_id) + self.__exception_msg = exception_msg + + def exception_msg(self): + return self.__exception_msg diff --git a/streaming/python/runtime/context_backend.py b/streaming/python/runtime/context_backend.py new file mode 100644 index 00000000..65e811cf --- /dev/null +++ b/streaming/python/runtime/context_backend.py @@ -0,0 +1,117 @@ +import logging +import os +from abc import ABC, abstractmethod +from os import path + +from ray.streaming.config import ConfigHelper, Config + +logger = logging.getLogger(__name__) + + +class ContextBackend(ABC): + @abstractmethod + def get(self, key): + pass + + @abstractmethod + def put(self, key, value): + pass + + @abstractmethod + def remove(self, key): + pass + + +class MemoryContextBackend(ContextBackend): + def __init__(self, conf): + self.__dic = dict() + + def get(self, key): + return self.__dic.get(key) + + def put(self, key, value): + self.__dic[key] = value + + def remove(self, key): + if key in self.__dic: + del self.__dic[key] + + +class LocalFileContextBackend(ContextBackend): + 
def __init__(self, conf): + self.__dir = ConfigHelper.get_cp_local_file_root_dir(conf) + logger.info("Start init local file state backend, root_dir={}.".format( + self.__dir)) + try: + os.mkdir(self.__dir) + except FileExistsError: + logger.info("dir already exists, skipped.") + + def put(self, key, value): + logger.info("Put value of key {} start.".format(key)) + with open(self.__gen_file_path(key), "wb") as f: + f.write(value) + + def get(self, key): + logger.info("Get value of key {} start.".format(key)) + full_path = self.__gen_file_path(key) + if not os.path.isfile(full_path): + return None + with open(full_path, "rb") as f: + return f.read() + + def remove(self, key): + logger.info("Remove value of key {} start.".format(key)) + try: + os.remove(self.__gen_file_path(key)) + except Exception: + # ignore exception + pass + + def rename(self, src, dst): + logger.info("rename {} to {}".format(src, dst)) + os.rename(self.__gen_file_path(src), self.__gen_file_path(dst)) + + def exists(self, key) -> bool: + return os.path.exists(key) + + def __gen_file_path(self, key): + return path.join(self.__dir, key) + + +class AtomicFsContextBackend(LocalFileContextBackend): + def __init__(self, conf): + super().__init__(conf) + self.__tmp_flag = "_tmp" + + def put(self, key, value): + tmp_key = key + self.__tmp_flag + if super().exists(tmp_key) and not super().exists(key): + super().rename(tmp_key, key) + super().put(tmp_key, value) + super().remove(key) + super().rename(tmp_key, key) + + def get(self, key): + tmp_key = key + self.__tmp_flag + if super().exists(tmp_key) and not super().exists(key): + return super().get(tmp_key) + return super().get(key) + + def remove(self, key): + tmp_key = key + self.__tmp_flag + if super().exists(tmp_key): + super().remove(tmp_key) + super().remove(key) + + +class ContextBackendFactory: + @staticmethod + def get_context_backend(worker_config) -> ContextBackend: + backend_type = ConfigHelper.get_cp_context_backend_type(worker_config) + 
context_backend = None + if backend_type == Config.CP_STATE_BACKEND_LOCAL_FILE: + context_backend = AtomicFsContextBackend(worker_config) + elif backend_type == Config.CP_STATE_BACKEND_MEMORY: + context_backend = MemoryContextBackend(worker_config) + return context_backend diff --git a/streaming/python/runtime/failover.py b/streaming/python/runtime/failover.py new file mode 100644 index 00000000..702cdbab --- /dev/null +++ b/streaming/python/runtime/failover.py @@ -0,0 +1,30 @@ +class Barrier: + """ + barrier + """ + + def __init__(self, id): + self.id = id + + def __str__(self): + return "Barrier [id:%s]" % self.id + + +class OpCheckpointInfo: + """ + operator checkpoint info + """ + + def __init__(self, + operator_point=None, + input_points=None, + output_points=None, + checkpoint_id=None): + if input_points is None: + input_points = {} + if output_points is None: + output_points = {} + self.operator_point = operator_point + self.input_points = input_points + self.output_points = output_points + self.checkpoint_id = checkpoint_id diff --git a/streaming/python/runtime/gateway_client.py b/streaming/python/runtime/gateway_client.py new file mode 100644 index 00000000..455ccf23 --- /dev/null +++ b/streaming/python/runtime/gateway_client.py @@ -0,0 +1,78 @@ +# -*- coding: UTF-8 -*- +"""Module to interact between java and python +""" + +import msgpack +import ray + + +class GatewayClient: + """GatewayClient is used to interact with `PythonGateway` java actor""" + + _PYTHON_GATEWAY_CLASSNAME = \ + b"io.ray.streaming.runtime.python.PythonGateway" + + def __init__(self): + self._python_gateway_actor = ray.java_actor_class( + GatewayClient._PYTHON_GATEWAY_CLASSNAME).remote() + + def create_streaming_context(self): + call = self._python_gateway_actor.createStreamingContext.remote() + return deserialize(ray.get(call)) + + def with_config(self, conf): + call = self._python_gateway_actor.withConfig.remote(serialize(conf)) + ray.get(call) + + def execute(self, job_name): + call 
= self._python_gateway_actor.execute.remote(serialize(job_name)) + ray.get(call) + + def create_py_stream_source(self, serialized_func): + assert isinstance(serialized_func, bytes) + call = self._python_gateway_actor.createPythonStreamSource \ + .remote(serialized_func) + return deserialize(ray.get(call)) + + def create_py_func(self, serialized_func): + assert isinstance(serialized_func, bytes) + call = self._python_gateway_actor.createPyFunc.remote(serialized_func) + return deserialize(ray.get(call)) + + def create_py_partition(self, serialized_partition): + assert isinstance(serialized_partition, bytes) + call = self._python_gateway_actor.createPyPartition \ + .remote(serialized_partition) + return deserialize(ray.get(call)) + + def union(self, *streams): + serialized_streams = serialize(streams) + call = self._python_gateway_actor.union \ + .remote(serialized_streams) + return deserialize(ray.get(call)) + + def call_function(self, java_class, java_function, *args): + java_params = serialize([java_class, java_function] + list(args)) + call = self._python_gateway_actor.callFunction.remote(java_params) + return deserialize(ray.get(call)) + + def call_method(self, java_object, java_method, *args): + java_params = serialize([java_object, java_method] + list(args)) + call = self._python_gateway_actor.callMethod.remote(java_params) + return deserialize(ray.get(call)) + + def new_instance(self, java_class_name): + call = self._python_gateway_actor.newInstance.remote( + serialize(java_class_name)) + return deserialize(ray.get(call)) + + +def serialize(obj) -> bytes: + """Serialize a python object which can be deserialized by `PythonGateway` + """ + return msgpack.packb(obj, use_bin_type=True) + + +def deserialize(data: bytes): + """Deserialize the binary data serialized by `PythonGateway`""" + return msgpack.unpackb(data, raw=False, strict_map_key=False) diff --git a/streaming/python/runtime/graph.py b/streaming/python/runtime/graph.py new file mode 100644 index 
00000000..acfd5b83 --- /dev/null +++ b/streaming/python/runtime/graph.py @@ -0,0 +1,155 @@ +import enum +import logging + +import ray +import ray.streaming.generated.remote_call_pb2 as remote_call_pb +import ray.streaming.operator as operator +import ray.streaming.partition as partition +from ray._raylet import ActorID +from ray.actor import ActorHandle +from ray.streaming.config import Config +from ray.streaming.generated.streaming_pb2 import Language + +logger = logging.getLogger(__name__) + + +class NodeType(enum.Enum): + """ + SOURCE: Sources are where your program reads its input from + + TRANSFORM: Operators transform one or more DataStreams into a new + DataStream. Programs can combine multiple transformations into + sophisticated dataflow topologies. + + SINK: Sinks consume DataStreams and forward them to files, sockets, + external systems, or print them. + """ + SOURCE = 0 + TRANSFORM = 1 + SINK = 2 + + +class ExecutionEdge: + def __init__(self, execution_edge_pb, language): + self.source_execution_vertex_id = execution_edge_pb \ + .source_execution_vertex_id + self.target_execution_vertex_id = execution_edge_pb \ + .target_execution_vertex_id + partition_bytes = execution_edge_pb.partition + # Sink node doesn't have partition function, + # so we only deserialize partition_bytes when it's not None or empty + if language == Language.PYTHON and partition_bytes: + self.partition = partition.load_partition(partition_bytes) + + +class ExecutionVertex: + worker_actor: ActorHandle + + def __init__(self, execution_vertex_pb): + self.execution_vertex_id = execution_vertex_pb.execution_vertex_id + self.execution_job_vertex_id = execution_vertex_pb \ + .execution_job_vertex_id + self.execution_job_vertex_name = execution_vertex_pb \ + .execution_job_vertex_name + self.execution_vertex_index = execution_vertex_pb\ + .execution_vertex_index + self.parallelism = execution_vertex_pb.parallelism + if execution_vertex_pb\ + .language == Language.PYTHON: + # python operator 
descriptor + operator_bytes = execution_vertex_pb.operator + if execution_vertex_pb.chained: + logger.info("Load chained operator") + self.stream_operator = operator.load_chained_operator( + operator_bytes) + else: + logger.info("Load operator") + self.stream_operator = operator.load_operator(operator_bytes) + self.worker_actor = None + if execution_vertex_pb.worker_actor: + self.worker_actor = ray.actor.ActorHandle. \ + _deserialization_helper(execution_vertex_pb.worker_actor) + self.container_id = execution_vertex_pb.container_id + self.build_time = execution_vertex_pb.build_time + self.language = execution_vertex_pb.language + self.config = execution_vertex_pb.config + self.resource = execution_vertex_pb.resource + + @property + def execution_vertex_name(self): + return "{}_{}_{}".format(self.execution_job_vertex_id, + self.execution_job_vertex_name, + self.execution_vertex_id) + + +class ExecutionVertexContext: + actor_id: ActorID + execution_vertex: ExecutionVertex + + def __init__( + self, + execution_vertex_context_pb: remote_call_pb.ExecutionVertexContext + ): + self.execution_vertex = ExecutionVertex( + execution_vertex_context_pb.current_execution_vertex) + self.job_name = self.execution_vertex.config[Config.STREAMING_JOB_NAME] + self.exe_vertex_name = self.execution_vertex.execution_vertex_name + self.actor_id = self.execution_vertex.worker_actor._ray_actor_id + self.upstream_execution_vertices = [ + ExecutionVertex(vertex) for vertex in + execution_vertex_context_pb.upstream_execution_vertices + ] + self.downstream_execution_vertices = [ + ExecutionVertex(vertex) for vertex in + execution_vertex_context_pb.downstream_execution_vertices + ] + self.input_execution_edges = [ + ExecutionEdge(edge, self.execution_vertex.language) + for edge in execution_vertex_context_pb.input_execution_edges + ] + self.output_execution_edges = [ + ExecutionEdge(edge, self.execution_vertex.language) + for edge in execution_vertex_context_pb.output_execution_edges + ] + + def 
get_parallelism(self): + return self.execution_vertex.parallelism + + def get_upstream_parallelism(self): + if self.upstream_execution_vertices: + return self.upstream_execution_vertices[0].parallelism + return 0 + + def get_downstream_parallelism(self): + if self.downstream_execution_vertices: + return self.downstream_execution_vertices[0].parallelism + return 0 + + @property + def build_time(self): + return self.execution_vertex.build_time + + @property + def stream_operator(self): + return self.execution_vertex.stream_operator + + @property + def config(self): + return self.execution_vertex.config + + def get_task_id(self): + return self.execution_vertex.execution_vertex_id + + def get_source_actor_by_execution_vertex_id(self, execution_vertex_id): + for execution_vertex in self.upstream_execution_vertices: + if execution_vertex.execution_vertex_id == execution_vertex_id: + return execution_vertex.worker_actor + raise Exception( + "Vertex {} does not exist!".format(execution_vertex_id)) + + def get_target_actor_by_execution_vertex_id(self, execution_vertex_id): + for execution_vertex in self.downstream_execution_vertices: + if execution_vertex.execution_vertex_id == execution_vertex_id: + return execution_vertex.worker_actor + raise Exception( + "Vertex {} does not exist!".format(execution_vertex_id)) diff --git a/streaming/python/runtime/processor.py b/streaming/python/runtime/processor.py new file mode 100644 index 00000000..1083713e --- /dev/null +++ b/streaming/python/runtime/processor.py @@ -0,0 +1,127 @@ +import logging +from abc import ABC, abstractmethod + +import ray.streaming.context as context +from ray.streaming import message +from ray.streaming.operator import OperatorType + +logger = logging.getLogger(__name__) + + +class Processor(ABC): + """The base interface for all processors.""" + + @abstractmethod + def open(self, collectors, runtime_context): + pass + + @abstractmethod + def process(self, record: message.Record): + pass + + @abstractmethod 
+    def close(self):
+        pass
+
+    @abstractmethod
+    def save_checkpoint(self):
+        pass
+
+    @abstractmethod
+    def load_checkpoint(self, checkpoint_obj):
+        pass
+
+
+class StreamingProcessor(Processor, ABC):
+    """StreamingProcessor is a process unit for an operator.
+
+    Holds the wrapped operator plus the collectors and runtime context
+    handed in by `open`, and forwards lifecycle calls to the operator.
+    """
+
+    def __init__(self, operator):
+        self.operator = operator
+        self.collectors = None
+        self.runtime_context = None
+
+    def open(self, collectors, runtime_context: context.RuntimeContext):
+        """Open the processor and its operator with output collectors."""
+        self.collectors = collectors
+        self.runtime_context = runtime_context
+        if self.operator is not None:
+            self.operator.open(collectors, runtime_context)
+        logger.info("Opened Processor {}".format(self))
+
+    def close(self):
+        self.operator.close()
+
+    def save_checkpoint(self):
+        return self.operator.save_checkpoint()
+
+    def load_checkpoint(self, checkpoint_obj):
+        self.operator.load_checkpoint(checkpoint_obj)
+
+
+class SourceProcessor(StreamingProcessor):
+    """Processor for :class:`ray.streaming.operator.SourceOperator` """
+
+    def __init__(self, operator):
+        super().__init__(operator)
+
+    def process(self, record):
+        # Sources generate records via fetch(); they never consume input.
+        raise Exception("SourceProcessor should not process record")
+
+    def fetch(self):
+        self.operator.fetch()
+
+
+class OneInputProcessor(StreamingProcessor):
+    """Processor for stream operator with one input"""
+
+    def __init__(self, operator):
+        super().__init__(operator)
+
+    def process(self, record):
+        self.operator.process_element(record)
+
+
+class TwoInputProcessor(StreamingProcessor):
+    """Processor for stream operator with two inputs"""
+
+    def __init__(self, operator):
+        super().__init__(operator)
+        self.left_stream = None
+        self.right_stream = None
+
+    def process(self, record: message.Record):
+        # Dispatch the record to the left or right slot of the operator
+        # based on which upstream stream it arrived from.
+        if record.stream == self.left_stream:
+            self.operator.process_element(record, None)
+        else:
+            self.operator.process_element(None, record)
+
+    @property
+    def left_stream(self):
+        # BUGFIX: was `return self.left_stream`, which re-entered this
+        # property and recursed infinitely; read the backing attribute.
+        return self._left_stream
+
+    @left_stream.setter
+    def left_stream(self, value):
+        self._left_stream = value
+
+    @property
+    def right_stream(self):
+        # BUGFIX: was `return self.right_stream` — infinite recursion.
+        return self._right_stream
+
+    @right_stream.setter
+    def right_stream(self, value):
+        # BUGFIX: was `self.right_stream = value`, which re-invoked this
+        # setter and recursed infinitely; write the backing attribute.
+        self._right_stream = value
+
+
+def build_processor(operator_instance):
+    """Create a processor for the given operator."""
+    operator_type = operator_instance.operator_type()
+    logger.info(
+        "Building StreamProcessor, operator type = {}, operator = {}.".format(
+            operator_type, operator_instance))
+    if operator_type == OperatorType.SOURCE:
+        return SourceProcessor(operator_instance)
+    elif operator_type == OperatorType.ONE_INPUT:
+        return OneInputProcessor(operator_instance)
+    elif operator_type == OperatorType.TWO_INPUT:
+        return TwoInputProcessor(operator_instance)
+    else:
+        raise Exception("Current operator type is not supported")
diff --git a/streaming/python/runtime/remote_call.py b/streaming/python/runtime/remote_call.py
new file mode 100644
index 00000000..4f5f082e
--- /dev/null
+++ b/streaming/python/runtime/remote_call.py
@@ -0,0 +1,95 @@
+import logging
+import os
+import ray
+import time
+from enum import Enum
+
+from ray.actor import ActorHandle
+from ray.streaming.generated import remote_call_pb2
+from ray.streaming.runtime.command\
+    import WorkerCommitReport, WorkerRollbackRequest
+
+logger = logging.getLogger(__name__)
+
+
+class CallResult:
+    """
+    Call Result
+    """
+
+    def __init__(self, success, result_code, result_msg, result_obj):
+        self.success = success
+        self.result_code = result_code
+        self.result_msg = result_msg
+        self.result_obj = result_obj
+
+    @staticmethod
+    def success(payload=None):
+        return CallResult(True, CallResultEnum.SUCCESS, None, payload)
+
+    @staticmethod
+    def fail(payload=None):
+        return CallResult(False, CallResultEnum.FAILED, None, payload)
+
+    @staticmethod
+    def skipped(msg=None):
+        return CallResult(True, CallResultEnum.SKIPPED, msg, None)
+
+    def is_success(self):
+        if self.result_code is CallResultEnum.SUCCESS:
+            return True
+
+        return False
+
+
+class CallResultEnum(Enum):
+    """
+    call result enum
+    """
+
+    
SUCCESS = 0 + FAILED = 1 + SKIPPED = 2 + + +class RemoteCallMst: + """ + remote call job master + """ + + @staticmethod + def request_job_worker_rollback(master: ActorHandle, + request: WorkerRollbackRequest): + logger.info("Remote call mst: request job worker rollback start.") + request_pb = remote_call_pb2.BaseWorkerCmd() + request_pb.actor_id = request.from_actor_id + request_pb.timestamp = int(time.time() * 1000.0) + rollback_request_pb = remote_call_pb2.WorkerRollbackRequest() + rollback_request_pb.exception_msg = request.exception_msg() + rollback_request_pb.worker_hostname = os.uname()[1] + rollback_request_pb.worker_pid = str(os.getpid()) + request_pb.detail.Pack(rollback_request_pb) + return_ids = master.requestJobWorkerRollback\ + .remote(request_pb.SerializeToString()) + result = remote_call_pb2.BoolResult() + result.ParseFromString(ray.get(return_ids)) + logger.info("Remote call mst: request job worker rollback finish.") + return result.boolRes + + @staticmethod + def report_job_worker_commit(master: ActorHandle, + report: WorkerCommitReport): + logger.info("Remote call mst: report job worker commit start.") + report_pb = remote_call_pb2.BaseWorkerCmd() + + report_pb.actor_id = report.from_actor_id + report_pb.timestamp = int(time.time() * 1000.0) + wk_commit = remote_call_pb2.WorkerCommitReport() + wk_commit.commit_checkpoint_id = report.commit_checkpoint_id + report_pb.detail.Pack(wk_commit) + return_id = master.reportJobWorkerCommit\ + .remote(report_pb.SerializeToString()) + result = remote_call_pb2.BoolResult() + result.ParseFromString(ray.get(return_id)) + logger.info("Remote call mst: report job worker commit finish.") + return result.boolRes diff --git a/streaming/python/runtime/serialization.py b/streaming/python/runtime/serialization.py new file mode 100644 index 00000000..600e1084 --- /dev/null +++ b/streaming/python/runtime/serialization.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +import pickle +import msgpack +from 
ray.streaming import message + +RECORD_TYPE_ID = 0 +KEY_RECORD_TYPE_ID = 1 +CROSS_LANG_TYPE_ID = 0 +JAVA_TYPE_ID = 1 +PYTHON_TYPE_ID = 2 + + +class Serializer(ABC): + @abstractmethod + def serialize(self, obj): + pass + + @abstractmethod + def deserialize(self, serialized_bytes): + pass + + +class PythonSerializer(Serializer): + def serialize(self, obj): + return pickle.dumps(obj) + + def deserialize(self, serialized_bytes): + return pickle.loads(serialized_bytes) + + +class CrossLangSerializer(Serializer): + """Serialize stream element between java/python""" + + def serialize(self, obj): + if type(obj) is message.Record: + fields = [RECORD_TYPE_ID, obj.stream, obj.value] + elif type(obj) is message.KeyRecord: + fields = [KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value] + else: + raise Exception("Unsupported value {}".format(obj)) + return msgpack.packb(fields, use_bin_type=True) + + def deserialize(self, data): + fields = msgpack.unpackb(data, raw=False) + if fields[0] == RECORD_TYPE_ID: + stream, value = fields[1:] + record = message.Record(value) + record.stream = stream + return record + elif fields[0] == KEY_RECORD_TYPE_ID: + stream, key, value = fields[1:] + key_record = message.KeyRecord(key, value) + key_record.stream = stream + return key_record + else: + raise Exception("Unsupported type id {}, type {}".format( + fields[0], type(fields[0]))) diff --git a/streaming/python/runtime/task.py b/streaming/python/runtime/task.py new file mode 100644 index 00000000..54ec3cf3 --- /dev/null +++ b/streaming/python/runtime/task.py @@ -0,0 +1,382 @@ +import logging +import pickle +import threading +import time +import typing +from abc import ABC, abstractmethod +from typing import Optional + +from ray.streaming.collector import OutputCollector +from ray.streaming.config import Config +from ray.streaming.context import RuntimeContextImpl +from ray.streaming.generated import remote_call_pb2 +from ray.streaming.runtime import serialization +from 
ray.streaming.runtime.command import WorkerCommitReport +from ray.streaming.runtime.failover import Barrier, OpCheckpointInfo +from ray.streaming.runtime.remote_call import RemoteCallMst +from ray.streaming.runtime.serialization import \ + PythonSerializer, CrossLangSerializer +from ray.streaming.runtime.transfer import CheckpointBarrier +from ray.streaming.runtime.transfer import DataMessage +from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader +from ray.streaming.runtime.transfer import ChannelRecoverInfo +from ray.streaming.runtime.transfer import ChannelInterruptException + +if typing.TYPE_CHECKING: + from ray.streaming.runtime.worker import JobWorker + from ray.streaming.runtime.processor import Processor, SourceProcessor + +logger = logging.getLogger(__name__) + + +class StreamTask(ABC): + """Base class for all streaming tasks. Each task runs a processor.""" + + def __init__(self, task_id: int, processor: "Processor", + worker: "JobWorker", last_checkpoint_id: int): + self.worker_context = worker.worker_context + self.vertex_context = worker.execution_vertex_context + self.task_id = task_id + self.processor = processor + self.worker = worker + self.config: dict = worker.config + self.reader: Optional[DataReader] = None + self.writer: Optional[DataWriter] = None + self.is_initial_state = True + self.last_checkpoint_id: int = last_checkpoint_id + self.thread = threading.Thread(target=self.run, daemon=True) + + def do_checkpoint(self, checkpoint_id: int, input_points): + logger.info("Start do checkpoint, cp id {}, inputPoints {}.".format( + checkpoint_id, input_points)) + + output_points = None + if self.writer is not None: + output_points = self.writer.get_output_checkpoints() + + operator_checkpoint = self.processor.save_checkpoint() + op_checkpoint_info = OpCheckpointInfo( + operator_checkpoint, input_points, output_points, checkpoint_id) + self.__save_cp_state_and_report(op_checkpoint_info, checkpoint_id) + + barrier_pb = 
remote_call_pb2.Barrier() + barrier_pb.id = checkpoint_id + byte_buffer = barrier_pb.SerializeToString() + if self.writer is not None: + self.writer.broadcast_barrier(checkpoint_id, byte_buffer) + logger.info("Operator checkpoint {} finish.".format(checkpoint_id)) + + def __save_cp_state_and_report(self, op_checkpoint_info, checkpoint_id): + logger.info( + "Start to save cp state and report, checkpoint id is {}.".format( + checkpoint_id)) + self.__save_cp(op_checkpoint_info, checkpoint_id) + self.__report_commit(checkpoint_id) + self.last_checkpoint_id = checkpoint_id + + def __save_cp(self, op_checkpoint_info, checkpoint_id): + logger.info("save operator cp, op_checkpoint_info={}".format( + op_checkpoint_info)) + cp_bytes = pickle.dumps(op_checkpoint_info) + self.worker.context_backend.put( + self.__gen_op_checkpoint_key(checkpoint_id), cp_bytes) + + def __report_commit(self, checkpoint_id: int): + logger.info("Report commit, checkpoint id {}.".format(checkpoint_id)) + report = WorkerCommitReport(self.vertex_context.actor_id.binary(), + checkpoint_id) + RemoteCallMst.report_job_worker_commit(self.worker.master_actor, + report) + + def clear_expired_cp_state(self, checkpoint_id): + cp_key = self.__gen_op_checkpoint_key(checkpoint_id) + self.worker.context_backend.remove(cp_key) + + def clear_expired_queue_msg(self, checkpoint_id): + # clear operator checkpoint + if self.writer is not None: + self.writer.clear_checkpoint(checkpoint_id) + + def request_rollback(self, exception_msg: str): + self.worker.request_rollback(exception_msg) + + def __gen_op_checkpoint_key(self, checkpoint_id): + op_checkpoint_key = Config.JOB_WORKER_OP_CHECKPOINT_PREFIX_KEY + str( + self.vertex_context.job_name) + "_" + str( + self.vertex_context.exe_vertex_name) + "_" + str(checkpoint_id) + logger.info( + "Generate op checkpoint key {}. 
".format(op_checkpoint_key)) + return op_checkpoint_key + + def prepare_task(self, is_recreate: bool): + logger.info( + "Preparing stream task, is_recreate={}.".format(is_recreate)) + channel_conf = dict(self.worker.config) + channel_size = int( + self.worker.config.get(Config.CHANNEL_SIZE, + Config.CHANNEL_SIZE_DEFAULT)) + channel_conf[Config.CHANNEL_SIZE] = channel_size + channel_conf[Config.CHANNEL_TYPE] = self.worker.config \ + .get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL) + + execution_vertex_context = self.worker.execution_vertex_context + build_time = execution_vertex_context.build_time + + # when use memory state, if actor throw exception, will miss state + op_checkpoint_info = OpCheckpointInfo() + + cp_bytes = None + # get operator checkpoint + if is_recreate: + cp_key = self.__gen_op_checkpoint_key(self.last_checkpoint_id) + logger.info("Getting task checkpoints from state, " + "cpKey={}, checkpointId={}.".format( + cp_key, self.last_checkpoint_id)) + cp_bytes = self.worker.context_backend.get(cp_key) + if cp_bytes is None: + msg = "Task recover failed, checkpoint is null!"\ + "cpKey={}".format(cp_key) + raise RuntimeError(msg) + + if cp_bytes is not None: + op_checkpoint_info = pickle.loads(cp_bytes) + self.processor.load_checkpoint(op_checkpoint_info.operator_point) + logger.info("Stream task recover from checkpoint state," + "checkpoint bytes len={}, checkpointInfo={}.".format( + cp_bytes.__len__(), op_checkpoint_info)) + + # writers + collectors = [] + output_actors_map = {} + for edge in execution_vertex_context.output_execution_edges: + target_task_id = edge.target_execution_vertex_id + target_actor = execution_vertex_context \ + .get_target_actor_by_execution_vertex_id(target_task_id) + channel_name = ChannelID.gen_id(self.task_id, target_task_id, + build_time) + output_actors_map[channel_name] = target_actor + + if len(output_actors_map) > 0: + channel_str_ids = list(output_actors_map.keys()) + target_actors = list(output_actors_map.values()) + 
logger.info("Create DataWriter channel_ids {}," + "target_actors {}, output_points={}.".format( + channel_str_ids, target_actors, + op_checkpoint_info.output_points)) + self.writer = DataWriter(channel_str_ids, target_actors, + channel_conf) + logger.info("Create DataWriter succeed channel_ids {}, " + "target_actors {}.".format(channel_str_ids, + target_actors)) + for edge in execution_vertex_context.output_execution_edges: + collectors.append( + OutputCollector(self.writer, channel_str_ids, + target_actors, edge.partition)) + + # readers + input_actor_map = {} + for edge in execution_vertex_context.input_execution_edges: + source_task_id = edge.source_execution_vertex_id + source_actor = execution_vertex_context \ + .get_source_actor_by_execution_vertex_id(source_task_id) + channel_name = ChannelID.gen_id(source_task_id, self.task_id, + build_time) + input_actor_map[channel_name] = source_actor + + if len(input_actor_map) > 0: + channel_str_ids = list(input_actor_map.keys()) + from_actors = list(input_actor_map.values()) + logger.info("Create DataReader, channels {}," + "input_actors {}, input_points={}.".format( + channel_str_ids, from_actors, + op_checkpoint_info.input_points)) + self.reader = DataReader(channel_str_ids, from_actors, + channel_conf) + + def exit_handler(): + # Make DataReader stop read data when MockQueue destructor + # gets called to avoid crash + self.cancel_task() + + import atexit + atexit.register(exit_handler) + + runtime_context = RuntimeContextImpl( + self.worker.task_id, + execution_vertex_context.execution_vertex.execution_vertex_index, + execution_vertex_context.get_parallelism(), + config=channel_conf, + job_config=channel_conf) + logger.info("open Processor {}".format(self.processor)) + self.processor.open(collectors, runtime_context) + + # immediately save cp. In case of FO in cp 0 + # or use old cp in multi node FO. 
+ self.__save_cp(op_checkpoint_info, self.last_checkpoint_id) + + def recover(self, is_recreate: bool): + self.prepare_task(is_recreate) + + recover_info = ChannelRecoverInfo() + if self.reader is not None: + recover_info = self.reader.get_channel_recover_info() + + self.thread.start() + + logger.info("Start operator success.") + return recover_info + + @abstractmethod + def run(self): + pass + + @abstractmethod + def cancel_task(self): + pass + + @abstractmethod + def commit_trigger(self, barrier: Barrier) -> bool: + pass + + +class InputStreamTask(StreamTask): + """Base class for stream tasks that execute a + :class:`runtime.processor.OneInputProcessor` or + :class:`runtime.processor.TwoInputProcessor` """ + + def commit_trigger(self, barrier): + raise RuntimeError( + "commit_trigger is only supported in SourceStreamTask.") + + def __init__(self, task_id, processor_instance, worker, + last_checkpoint_id): + super().__init__(task_id, processor_instance, worker, + last_checkpoint_id) + self.running = True + self.stopped = False + self.read_timeout_millis = \ + int(worker.config.get(Config.READ_TIMEOUT_MS, + Config.DEFAULT_READ_TIMEOUT_MS)) + self.python_serializer = PythonSerializer() + self.cross_lang_serializer = CrossLangSerializer() + + def run(self): + logger.info("Input task thread start.") + try: + while self.running: + self.worker.initial_state_lock.acquire() + try: + item = self.reader.read(self.read_timeout_millis) + self.is_initial_state = False + finally: + self.worker.initial_state_lock.release() + + if item is None: + continue + + if isinstance(item, DataMessage): + msg_data = item.body + type_id = msg_data[0] + if type_id == serialization.PYTHON_TYPE_ID: + msg = self.python_serializer.deserialize(msg_data[1:]) + else: + msg = self.cross_lang_serializer.deserialize( + msg_data[1:]) + self.processor.process(msg) + elif isinstance(item, CheckpointBarrier): + logger.info("Got barrier:{}".format(item)) + logger.info("Start to do checkpoint {}.".format( + 
item.checkpoint_id)) + + input_points = item.get_input_checkpoints() + + self.do_checkpoint(item.checkpoint_id, input_points) + logger.info("Do checkpoint {} success.".format( + item.checkpoint_id)) + else: + raise RuntimeError( + "Unknown item type! item={}".format(item)) + + except ChannelInterruptException: + logger.info("queue has stopped.") + except BaseException as e: + logger.exception( + "Last success checkpointId={}, now occur error.".format( + self.last_checkpoint_id)) + self.request_rollback(str(e)) + + logger.info("Source fetcher thread exit.") + self.stopped = True + + def cancel_task(self): + self.running = False + while not self.stopped: + time.sleep(0.5) + pass + + +class OneInputStreamTask(InputStreamTask): + """A stream task for executing :class:`runtime.processor.OneInputProcessor` + """ + + def __init__(self, task_id, processor_instance, worker, + last_checkpoint_id): + super().__init__(task_id, processor_instance, worker, + last_checkpoint_id) + + +class SourceStreamTask(StreamTask): + """A stream task for executing :class:`runtime.processor.SourceProcessor` + """ + processor: "SourceProcessor" + + def __init__(self, task_id: int, processor_instance: "SourceProcessor", + worker: "JobWorker", last_checkpoint_id): + super().__init__(task_id, processor_instance, worker, + last_checkpoint_id) + self.running = True + self.stopped = False + self.__pending_barrier: Optional[Barrier] = None + + def run(self): + logger.info("Source task thread start.") + try: + while self.running: + self.processor.fetch() + # check checkpoint + if self.__pending_barrier is not None: + # source fetcher only have outputPoints + barrier = self.__pending_barrier + logger.info("Start to do checkpoint {}.".format( + barrier.id)) + self.do_checkpoint(barrier.id, barrier) + logger.info("Finish to do checkpoint {}.".format( + barrier.id)) + self.__pending_barrier = None + + except ChannelInterruptException: + logger.info("queue has stopped.") + except Exception as e: + 
logger.exception( + "Last success checkpointId={}, now occur error.".format( + self.last_checkpoint_id)) + self.request_rollback(str(e)) + + logger.info("Source fetcher thread exit.") + self.stopped = True + + def commit_trigger(self, barrier): + if self.__pending_barrier is not None: + logger.warning( + "Last barrier is not broadcast now, skip this barrier trigger." + ) + return False + + self.__pending_barrier = barrier + return True + + def cancel_task(self): + self.running = False + while not self.stopped: + time.sleep(0.5) + pass diff --git a/streaming/python/runtime/transfer.py b/streaming/python/runtime/transfer.py new file mode 100644 index 00000000..4cb482de --- /dev/null +++ b/streaming/python/runtime/transfer.py @@ -0,0 +1,506 @@ +import logging +import random +from queue import Queue +from typing import List +from enum import Enum +from abc import ABC, abstractmethod + +import ray +import ray.streaming._streaming as _streaming +import ray.streaming.generated.streaming_pb2 as streaming_pb +from ray.actor import ActorHandle +from ray.streaming.config import Config +from ray._raylet import JavaFunctionDescriptor +from ray._raylet import PythonFunctionDescriptor +from ray._raylet import Language + +CHANNEL_ID_LEN = ray.ObjectID.nil().size() +logger = logging.getLogger(__name__) + + +class ChannelID: + """ + ChannelID is used to identify a transfer channel between + a upstream worker and downstream worker. 
+ """ + + def __init__(self, channel_id_str: str): + """ + Args: + channel_id_str: string representation of channel id + """ + self.channel_id_str = channel_id_str + self.object_qid = ray.ObjectRef( + channel_id_str_to_bytes(channel_id_str)) + + def __eq__(self, other): + if other is None: + return False + if type(other) is ChannelID: + return self.channel_id_str == other.channel_id_str + else: + return False + + def __hash__(self): + return hash(self.channel_id_str) + + def __repr__(self): + return self.channel_id_str + + @staticmethod + def gen_random_id(): + """Generate a random channel id string + """ + res = "" + for i in range(CHANNEL_ID_LEN * 2): + res += str(chr(random.randint(0, 5) + ord("A"))) + return res + + @staticmethod + def gen_id(from_index, to_index, ts): + """Generate channel id, which is `CHANNEL_ID_LEN` character""" + channel_id = bytearray(CHANNEL_ID_LEN) + for i in range(11, 7, -1): + channel_id[i] = ts & 0xff + ts >>= 8 + channel_id[16] = (from_index & 0xffff) >> 8 + channel_id[17] = (from_index & 0xff) + channel_id[18] = (to_index & 0xffff) >> 8 + channel_id[19] = (to_index & 0xff) + return channel_bytes_to_str(bytes(channel_id)) + + +def channel_id_str_to_bytes(channel_id_str): + """ + Args: + channel_id_str: string representation of channel id + + Returns: + bytes representation of channel id + """ + assert type(channel_id_str) in [str, bytes] + if isinstance(channel_id_str, bytes): + return channel_id_str + qid_bytes = bytes.fromhex(channel_id_str) + assert len(qid_bytes) == CHANNEL_ID_LEN + return qid_bytes + + +def channel_bytes_to_str(id_bytes): + """ + Args: + id_bytes: bytes representation of channel id + + Returns: + string representation of channel id + """ + assert type(id_bytes) in [str, bytes] + if isinstance(id_bytes, str): + return id_bytes + return bytes.hex(id_bytes) + + +class Message(ABC): + @property + @abstractmethod + def body(self): + """Message data""" + pass + + @property + @abstractmethod + def timestamp(self): + 
"""Get timestamp when item is written by upstream DataWriter + """ + pass + + @property + @abstractmethod + def channel_id(self): + """Get string id of channel where data is coming from""" + pass + + @property + @abstractmethod + def message_id(self): + """Get message id of the message""" + pass + + +class DataMessage(Message): + """ + DataMessage represents data between upstream and downstream operator. + """ + + def __init__(self, + body, + timestamp, + message_id, + channel_id, + is_empty_message=False): + self.__body = body + self.__timestamp = timestamp + self.__channel_id = channel_id + self.__message_id = message_id + self.__is_empty_message = is_empty_message + + def __len__(self): + return len(self.__body) + + @property + def body(self): + return self.__body + + @property + def timestamp(self): + return self.__timestamp + + @property + def channel_id(self): + return self.__channel_id + + @property + def message_id(self): + return self.__message_id + + @property + def is_empty_message(self): + """Whether this message is an empty message. + Upstream DataWriter will send an empty message when this is no data + in specified interval. + """ + return self.__is_empty_message + + +class CheckpointBarrier(Message): + """ + CheckpointBarrier separates the records in the data stream into the set of + records that goes into the current snapshot, and the records that go into + the next snapshot. Each barrier carries the ID of the snapshot whose + records it pushed in front of it. 
+ """ + + def __init__(self, barrier_data, timestamp, message_id, channel_id, + offsets, barrier_id, barrier_type): + self.__barrier_data = barrier_data + self.__timestamp = timestamp + self.__message_id = message_id + self.__channel_id = channel_id + self.checkpoint_id = barrier_id + self.offsets = offsets + self.barrier_type = barrier_type + + @property + def body(self): + return self.__barrier_data + + @property + def timestamp(self): + return self.__timestamp + + @property + def channel_id(self): + return self.__channel_id + + @property + def message_id(self): + return self.__message_id + + def get_input_checkpoints(self): + return self.offsets + + def __str__(self): + return "Barrier(Checkpoint id : {})".format(self.checkpoint_id) + + +class ChannelCreationParametersBuilder: + """ + wrap initial parameters needed by a streaming queue + """ + _java_reader_async_function_descriptor = JavaFunctionDescriptor( + "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessage", + "([B)V") + _java_reader_sync_function_descriptor = JavaFunctionDescriptor( + "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessageSync", + "([B)[B") + _java_writer_async_function_descriptor = JavaFunctionDescriptor( + "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessage", + "([B)V") + _java_writer_sync_function_descriptor = JavaFunctionDescriptor( + "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessageSync", + "([B)[B") + _python_reader_async_function_descriptor = PythonFunctionDescriptor( + "ray.streaming.runtime.worker", "on_reader_message", "JobWorker") + _python_reader_sync_function_descriptor = PythonFunctionDescriptor( + "ray.streaming.runtime.worker", "on_reader_message_sync", "JobWorker") + _python_writer_async_function_descriptor = PythonFunctionDescriptor( + "ray.streaming.runtime.worker", "on_writer_message", "JobWorker") + _python_writer_sync_function_descriptor = PythonFunctionDescriptor( + "ray.streaming.runtime.worker", "on_writer_message_sync", 
"JobWorker") + + def get_parameters(self): + return self._parameters + + def __init__(self): + self._parameters = [] + + def build_input_queue_parameters(self, from_actors): + self.build_parameters(from_actors, + self._java_writer_async_function_descriptor, + self._java_writer_sync_function_descriptor, + self._python_writer_async_function_descriptor, + self._python_writer_sync_function_descriptor) + return self + + def build_output_queue_parameters(self, to_actors): + self.build_parameters(to_actors, + self._java_reader_async_function_descriptor, + self._java_reader_sync_function_descriptor, + self._python_reader_async_function_descriptor, + self._python_reader_sync_function_descriptor) + return self + + def build_parameters(self, actors, java_async_func, java_sync_func, + py_async_func, py_sync_func): + for handle in actors: + parameter = None + if handle._ray_actor_language == Language.PYTHON: + parameter = _streaming.ChannelCreationParameter( + handle._ray_actor_id, py_async_func, py_sync_func) + else: + parameter = _streaming.ChannelCreationParameter( + handle._ray_actor_id, java_async_func, java_sync_func) + self._parameters.append(parameter) + return self + + @staticmethod + def set_python_writer_function_descriptor(async_function, sync_function): + ChannelCreationParametersBuilder. \ + _python_writer_async_function_descriptor = async_function + ChannelCreationParametersBuilder. \ + _python_writer_sync_function_descriptor = sync_function + + @staticmethod + def set_python_reader_function_descriptor(async_function, sync_function): + ChannelCreationParametersBuilder. \ + _python_reader_async_function_descriptor = async_function + ChannelCreationParametersBuilder. 
\ + _python_reader_sync_function_descriptor = sync_function + + +class DataWriter: + """Data Writer is a wrapper of streaming c++ DataWriter, which sends data + to downstream workers + """ + + def __init__(self, output_channels, to_actors: List[ActorHandle], + conf: dict): + """Get DataWriter of output channels + Args: + output_channels: output channels ids + to_actors: downstream output actors + Returns: + DataWriter + """ + assert len(output_channels) > 0 + py_output_channels = [ + channel_id_str_to_bytes(qid_str) for qid_str in output_channels + ] + creation_parameters = ChannelCreationParametersBuilder() + creation_parameters.build_output_queue_parameters(to_actors) + channel_size = conf.get(Config.CHANNEL_SIZE, + Config.CHANNEL_SIZE_DEFAULT) + py_msg_ids = [0 for _ in range(len(output_channels))] + config_bytes = _to_native_conf(conf) + is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL + self.writer = _streaming.DataWriter.create( + py_output_channels, creation_parameters.get_parameters(), + channel_size, py_msg_ids, config_bytes, is_mock) + + logger.info("create DataWriter succeed") + + def write(self, channel_id: ChannelID, item: bytes): + """Write data into native channel + Args: + channel_id: channel id + item: bytes data + Returns: + msg_id + """ + assert type(item) == bytes + msg_id = self.writer.write(channel_id.object_qid, item) + return msg_id + + def broadcast_barrier(self, checkpoint_id: int, body: bytes): + """Broadcast barriers to all downstream channels + Args: + checkpoint_id: the checkpoint_id + body: barrier payload + """ + self.writer.broadcast_barrier(checkpoint_id, body) + + def get_output_checkpoints(self) -> List[int]: + """Get output offsets of all downstream channels + Returns: + a list contains current msg_id of each downstream channel + """ + return self.writer.get_output_checkpoints() + + def clear_checkpoint(self, checkpoint_id): + logger.info("producer start to clear checkpoint, checkpoint_id={}" + 
.format(checkpoint_id)) + self.writer.clear_checkpoint(checkpoint_id) + + def stop(self): + logger.info("stopping channel writer.") + self.writer.stop() + # destruct DataWriter + self.writer = None + + def close(self): + logger.info("closing channel writer.") + + +class DataReader: + """Data Reader is wrapper of streaming c++ DataReader, which read data + from channels of upstream workers + """ + + def __init__(self, input_channels: List, from_actors: List[ActorHandle], + conf: dict): + """Get DataReader of input channels + Args: + input_channels: input channels + from_actors: upstream input actors + Returns: + DataReader + """ + assert len(input_channels) > 0 + py_input_channels = [ + channel_id_str_to_bytes(qid_str) for qid_str in input_channels + ] + creation_parameters = ChannelCreationParametersBuilder() + creation_parameters.build_input_queue_parameters(from_actors) + py_msg_ids = [0 for _ in range(len(input_channels))] + timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1)) + config_bytes = _to_native_conf(conf) + self.__queue = Queue(10000) + is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL + self.reader, queues_creation_status = _streaming.DataReader.create( + py_input_channels, creation_parameters.get_parameters(), + py_msg_ids, timer_interval, config_bytes, is_mock) + + self.__creation_status = {} + for q, status in queues_creation_status.items(): + self.__creation_status[q] = ChannelCreationStatus(status) + logger.info("create DataReader succeed, creation_status={}".format( + self.__creation_status)) + + def read(self, timeout_millis): + """Read data from channel + Args: + timeout_millis: timeout millis when there is no data in channel + for this duration + Returns: + channel item + """ + if self.__queue.empty(): + messages = self.reader.read(timeout_millis) + for message in messages: + self.__queue.put(message) + + if self.__queue.empty(): + return None + return self.__queue.get() + + def get_channel_recover_info(self): + return 
ChannelRecoverInfo(self.__creation_status)
+
+    def stop(self):
+        logger.info("stopping Data Reader.")
+        self.reader.stop()
+        # destruct DataReader
+        self.reader = None
+
+    def close(self):
+        logger.info("closing Data Reader.")
+
+
+def _to_native_conf(conf):
+    """Convert a python dict config into serialized StreamingConfig protobuf
+    bytes consumed by the native transfer layer. Unknown keys are ignored.
+    """
+    config = streaming_pb.StreamingConfig()
+    if Config.STREAMING_JOB_NAME in conf:
+        config.job_name = conf[Config.STREAMING_JOB_NAME]
+    if Config.STREAMING_WORKER_NAME in conf:
+        config.worker_name = conf[Config.STREAMING_WORKER_NAME]
+    if Config.STREAMING_OP_NAME in conf:
+        config.op_name = conf[Config.STREAMING_OP_NAME]
+    # TODO set operator type
+    if Config.STREAMING_RING_BUFFER_CAPACITY in conf:
+        config.ring_buffer_capacity = \
+            conf[Config.STREAMING_RING_BUFFER_CAPACITY]
+    if Config.STREAMING_EMPTY_MESSAGE_INTERVAL in conf:
+        config.empty_message_interval = \
+            conf[Config.STREAMING_EMPTY_MESSAGE_INTERVAL]
+    if Config.FLOW_CONTROL_TYPE in conf:
+        # BUGFIX: was `conf.flow_control_type = ...` — attribute assignment
+        # on the input dict raises AttributeError and the setting was never
+        # written into the protobuf.
+        config.flow_control_type = conf[Config.FLOW_CONTROL_TYPE]
+    if Config.WRITER_CONSUMED_STEP in conf:
+        # BUGFIX: was assigned to `conf` instead of the protobuf `config`.
+        config.writer_consumed_step = \
+            conf[Config.WRITER_CONSUMED_STEP]
+    if Config.READER_CONSUMED_STEP in conf:
+        # BUGFIX: was assigned to `conf` instead of the protobuf `config`.
+        config.reader_consumed_step = \
+            conf[Config.READER_CONSUMED_STEP]
+    logger.info("conf: %s", str(config))
+    return config.SerializeToString()
+
+
+class ChannelInitException(Exception):
+    def __init__(self, msg, abnormal_channels):
+        self.abnormal_channels = abnormal_channels
+        self.msg = msg
+
+
+class ChannelInterruptException(Exception):
+    def __init__(self, msg=None):
+        self.msg = msg
+
+
+class ChannelRecoverInfo:
+    def __init__(self, queue_creation_status_map=None):
+        if queue_creation_status_map is None:
+            queue_creation_status_map = {}
+        self.__queue_creation_status_map = queue_creation_status_map
+
+    def get_creation_status(self):
+        return self.__queue_creation_status_map
+
+    def get_data_lost_queues(self):
+        data_lost_queues = set()
+        for (q, status) in self.__queue_creation_status_map.items():
+            if status == ChannelCreationStatus.DataLost:
+                
data_lost_queues.add(q) + return data_lost_queues + + def __str__(self): + return "QueueRecoverInfo [dataLostQueues=%s]" \ + % (self.get_data_lost_queues()) + + +class ChannelCreationStatus(Enum): + FreshStarted = 0 + PullOk = 1 + Timeout = 2 + DataLost = 3 + + +def channel_id_bytes_to_str(id_bytes): + """ + Args: + id_bytes: bytes representation of channel id + + Returns: + string representation of channel id + """ + assert type(id_bytes) in [str, bytes] + if isinstance(id_bytes, str): + return id_bytes + return bytes.hex(id_bytes) diff --git a/streaming/python/runtime/worker.py b/streaming/python/runtime/worker.py new file mode 100644 index 00000000..d6d8eb02 --- /dev/null +++ b/streaming/python/runtime/worker.py @@ -0,0 +1,386 @@ +import enum +import logging.config +import os +import threading +import time +from typing import Optional + +import ray +import ray.streaming.runtime.processor as processor +from ray.actor import ActorHandle +from ray.streaming.generated import remote_call_pb2 +from ray.streaming.runtime.command import WorkerRollbackRequest +from ray.streaming.runtime.failover import Barrier +from ray.streaming.runtime.graph import ExecutionVertexContext, ExecutionVertex +from ray.streaming.runtime.remote_call import CallResult, RemoteCallMst +from ray.streaming.runtime.context_backend import ContextBackendFactory +from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask +from ray.streaming.runtime.transfer import channel_bytes_to_str +from ray.streaming.config import Config +import ray.streaming._streaming as _streaming + +logger = logging.getLogger(__name__) + +# special flag to indicate this actor not ready +_NOT_READY_FLAG_ = b" " * 4 + + +@ray.remote +class JobWorker(object): + """A streaming job worker is used to execute user-defined function and + interact with `JobMaster`""" + master_actor: Optional[ActorHandle] + worker_context: Optional[remote_call_pb2.PythonJobWorkerContext] + execution_vertex_context: 
Optional[ExecutionVertexContext] + __need_rollback: bool + + def __init__(self, execution_vertex_pb_bytes): + logger.info("Creating job worker, pid={}".format(os.getpid())) + execution_vertex_pb = remote_call_pb2\ + .ExecutionVertexContext.ExecutionVertex() + execution_vertex_pb.ParseFromString(execution_vertex_pb_bytes) + self.execution_vertex = ExecutionVertex(execution_vertex_pb) + self.config = self.execution_vertex.config + self.worker_context = None + self.execution_vertex_context = None + self.task_id = None + self.task = None + self.stream_processor = None + self.master_actor = None + self.context_backend = ContextBackendFactory.get_context_backend( + self.config) + self.initial_state_lock = threading.Lock() + self.__rollback_cnt: int = 0 + self.__is_recreate: bool = False + self.__state = WorkerState() + self.__need_rollback = True + self.reader_client = None + self.writer_client = None + try: + # load checkpoint + was_reconstructed = ray.get_runtime_context( + ).was_current_actor_reconstructed + + logger.info( + "Worker was reconstructed: {}".format(was_reconstructed)) + if was_reconstructed: + job_worker_context_key = self.__get_job_worker_context_key() + logger.info("Worker get checkpoint state by key: {}.".format( + job_worker_context_key)) + context_bytes = self.context_backend.get( + job_worker_context_key) + if context_bytes is not None and context_bytes.__len__() > 0: + self.init(context_bytes) + self.request_rollback( + "Python worker recover from checkpoint.") + else: + logger.error( + "Error! Worker get checkpoint state by key {}" + " returns None, please check your state backend" + ", only reliable state backend supports fail-over." + .format(job_worker_context_key)) + except Exception: + logger.exception("Error in __init__ of JobWorker") + logger.info("Creating job worker succeeded. 
worker config {}".format( + self.config)) + + def init(self, worker_context_bytes): + logger.info("Start to init job worker") + try: + # deserialize context + worker_context = remote_call_pb2.PythonJobWorkerContext() + worker_context.ParseFromString(worker_context_bytes) + self.worker_context = worker_context + self.master_actor = ActorHandle._deserialization_helper( + worker_context.master_actor) + + # build vertex context from pb + self.execution_vertex_context = ExecutionVertexContext( + worker_context.execution_vertex_context) + self.execution_vertex = self\ + .execution_vertex_context.execution_vertex + + # save context + job_worker_context_key = self.__get_job_worker_context_key() + self.context_backend.put(job_worker_context_key, + worker_context_bytes) + + # use vertex id as task id + self.task_id = self.execution_vertex_context.get_task_id() + # build and get processor from operator + operator = self.execution_vertex_context.stream_operator + self.stream_processor = processor.build_processor(operator) + logger.info("Initializing job worker, exe_vertex_name={}," + "task_id: {}, operator: {}, pid={}".format( + self.execution_vertex_context.exe_vertex_name, + self.task_id, self.stream_processor, os.getpid())) + + # get config from vertex + self.config = self.execution_vertex_context.config + + if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL): + self.reader_client = _streaming.ReaderClient() + self.writer_client = _streaming.WriterClient() + + logger.info("Job worker init succeeded.") + except Exception: + logger.exception("Error when init job worker.") + return False + return True + + def create_stream_task(self, checkpoint_id): + if isinstance(self.stream_processor, processor.SourceProcessor): + return SourceStreamTask(self.task_id, self.stream_processor, self, + checkpoint_id) + elif isinstance(self.stream_processor, processor.OneInputProcessor): + return OneInputStreamTask(self.task_id, self.stream_processor, + self, checkpoint_id) + else: + 
raise Exception("Unsupported processor type: " + + str(type(self.stream_processor))) + + def rollback(self, checkpoint_id_bytes): + checkpoint_id_pb = remote_call_pb2.CheckpointId() + checkpoint_id_pb.ParseFromString(checkpoint_id_bytes) + checkpoint_id = checkpoint_id_pb.checkpoint_id + + logger.info("Start rollback, checkpoint_id={}".format(checkpoint_id)) + + self.__rollback_cnt += 1 + if self.__rollback_cnt > 1: + self.__is_recreate = True + # skip useless rollback + self.initial_state_lock.acquire() + try: + if self.task is not None and self.task.thread.is_alive()\ + and checkpoint_id == self.task.last_checkpoint_id\ + and self.task.is_initial_state: + logger.info( + "Task is already in initial state, skip this rollback.") + return self.__gen_call_result( + CallResult.skipped( + "Task is already in initial state, skip this rollback." + )) + finally: + self.initial_state_lock.release() + + # restart task + try: + if self.task is not None: + # make sure the runner is closed + self.task.cancel_task() + del self.task + + self.task = self.create_stream_task(checkpoint_id) + + q_recover_info = self.task.recover(self.__is_recreate) + + self.__state.set_type(StateType.RUNNING) + self.__need_rollback = False + + logger.info( + "Rollback success, checkpoint is {}, qRecoverInfo is {}.". + format(checkpoint_id, q_recover_info)) + + return self.__gen_call_result(CallResult.success(q_recover_info)) + except Exception: + logger.exception("Rollback has exception.") + return self.__gen_call_result(CallResult.fail()) + + def on_reader_message(self, *buffers): + """Called by upstream queue writer to send data message to downstream + queue reader. + """ + if self.reader_client is None: + logger.info("reader_client is None, skip writer transfer") + return + self.reader_client.on_reader_message(*buffers) + + def on_reader_message_sync(self, buffer: bytes): + """Called by upstream queue writer to send + control message to downstream downstream queue reader. 
+ """ + if self.reader_client is None: + logger.info("task is None, skip reader transfer") + return _NOT_READY_FLAG_ + result = self.reader_client.on_reader_message_sync(buffer) + return result.to_pybytes() + + def on_writer_message(self, buffer: bytes): + """Called by downstream queue reader to send notify message to + upstream queue writer. + """ + if self.writer_client is None: + logger.info("writer_client is None, skip writer transfer") + return + self.writer_client.on_writer_message(buffer) + + def on_writer_message_sync(self, buffer: bytes): + """Called by downstream queue reader to send control message to + upstream queue writer. + """ + if self.writer_client is None: + return _NOT_READY_FLAG_ + result = self.writer_client.on_writer_message_sync(buffer) + return result.to_pybytes() + + def shutdown_without_reconstruction(self): + logger.info("Python worker shutdown without reconstruction.") + ray.actor.exit_actor() + + def notify_checkpoint_timeout(self, checkpoint_id_bytes): + pass + + def commit(self, barrier_bytes): + barrier_pb = remote_call_pb2.Barrier() + barrier_pb.ParseFromString(barrier_bytes) + barrier = Barrier(barrier_pb.id) + logger.info("Receive trigger, barrier is {}.".format(barrier)) + + if self.task is not None: + self.task.commit_trigger(barrier) + ret = remote_call_pb2.BoolResult() + ret.boolRes = True + return ret.SerializeToString() + + def clear_expired_cp(self, state_checkpoint_id_bytes, + queue_checkpoint_id_bytes): + state_checkpoint_id = self.__parse_to_checkpoint_id( + state_checkpoint_id_bytes) + queue_checkpoint_id = self.__parse_to_checkpoint_id( + queue_checkpoint_id_bytes) + logger.info("Start to clear expired checkpoint, checkpoint_id={}," + "queue_checkpoint_id={}, exe_vertex_name={}.".format( + state_checkpoint_id, queue_checkpoint_id, + self.execution_vertex_context.exe_vertex_name)) + + ret = remote_call_pb2.BoolResult() + ret.boolRes = self.__clear_expired_cp_state(state_checkpoint_id) \ + if state_checkpoint_id > 0 
else True + ret.boolRes &= self.__clear_expired_queue_msg(queue_checkpoint_id) + logger.info( + "Clear expired checkpoint done, result={}, checkpoint_id={}," + "queue_checkpoint_id={}, exe_vertex_name={}.".format( + ret.boolRes, state_checkpoint_id, queue_checkpoint_id, + self.execution_vertex_context.exe_vertex_name)) + return ret.SerializeToString() + + def __clear_expired_cp_state(self, checkpoint_id): + if self.__need_rollback: + logger.warning("Need rollback, skip clear_expired_cp_state" + ", checkpoint id: {}".format(checkpoint_id)) + return False + + logger.info("Clear expired checkpoint state, cp id is {}.".format( + checkpoint_id)) + + if self.task is not None: + self.task.clear_expired_cp_state(checkpoint_id) + return True + + def __clear_expired_queue_msg(self, checkpoint_id): + if self.__need_rollback: + logger.warning("Need rollback, skip clear_expired_queue_msg" + ", checkpoint id: {}".format(checkpoint_id)) + return False + + logger.info("Clear expired queue msg, checkpoint_id is {}.".format( + checkpoint_id)) + + if self.task is not None: + self.task.clear_expired_queue_msg(checkpoint_id) + return True + + def __parse_to_checkpoint_id(self, checkpoint_id_bytes): + checkpoint_id_pb = remote_call_pb2.CheckpointId() + checkpoint_id_pb.ParseFromString(checkpoint_id_bytes) + return checkpoint_id_pb.checkpoint_id + + def check_if_need_rollback(self): + ret = remote_call_pb2.BoolResult() + ret.boolRes = self.__need_rollback + return ret.SerializeToString() + + def request_rollback(self, exception_msg="Python exception."): + logger.info("Request rollback.") + + self.__need_rollback = True + self.__is_recreate = True + + request_ret = False + for i in range(Config.REQUEST_ROLLBACK_RETRY_TIMES): + logger.info("request rollback {} time".format(i)) + try: + request_ret = RemoteCallMst.request_job_worker_rollback( + self.master_actor, + WorkerRollbackRequest( + self.execution_vertex_context.actor_id.binary(), + "Exception msg=%s, retry time=%d." 
% (exception_msg, + i))) + except Exception: + logger.exception("Unexpected error when rollback") + logger.info("request rollback {} time, ret={}".format( + i, request_ret)) + if not request_ret: + logger.warning( + "Request rollback return false" + ", maybe it's invalid request, try to sleep 1s.") + time.sleep(1) + else: + break + if not request_ret: + logger.warning("Request failed after retry {} times," + "now worker shutdown without reconstruction." + .format(Config.REQUEST_ROLLBACK_RETRY_TIMES)) + self.shutdown_without_reconstruction() + + self.__state.set_type(StateType.WAIT_ROLLBACK) + + def __gen_call_result(self, call_result): + call_result_pb = remote_call_pb2.CallResult() + + call_result_pb.success = call_result.success + call_result_pb.result_code = call_result.result_code.value + if call_result.result_msg is not None: + call_result_pb.result_msg = call_result.result_msg + + if call_result.result_obj is not None: + q_recover_info = call_result.result_obj + for q, status in q_recover_info.get_creation_status().items(): + call_result_pb.result_obj.creation_status[channel_bytes_to_str( + q)] = status.value + + return call_result_pb.SerializeToString() + + def _gen_unique_key(self, key_prefix): + return key_prefix \ + + str(self.config.get(Config.STREAMING_JOB_NAME)) \ + + "_" + str(self.execution_vertex.execution_vertex_id) + + def __get_job_worker_context_key(self) -> str: + return self._gen_unique_key(Config.JOB_WORKER_CONTEXT_KEY) + + +class WorkerState: + """ + worker state + """ + + def __init__(self): + self.__type = StateType.INIT + + def set_type(self, type): + self.__type = type + + def get_type(self): + return self.__type + + +class StateType(enum.Enum): + """ + state type + """ + + INIT = 1 + RUNNING = 2 + WAIT_ROLLBACK = 3 diff --git a/streaming/python/tests/__init__.py b/streaming/python/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/streaming/python/tests/test_direct_transfer.py 
b/streaming/python/tests/test_direct_transfer.py new file mode 100644 index 00000000..9a9f2892 --- /dev/null +++ b/streaming/python/tests/test_direct_transfer.py @@ -0,0 +1,122 @@ +import pickle +import threading +import time + +import ray +import ray.streaming._streaming as _streaming +import ray.streaming.runtime.transfer as transfer +from ray._raylet import PythonFunctionDescriptor +from ray.streaming.config import Config + + +@ray.remote +class Worker: + def __init__(self): + self.writer_client = _streaming.WriterClient() + self.reader_client = _streaming.ReaderClient() + self.writer = None + self.output_channel_id = None + self.reader = None + + def init_writer(self, output_channel, reader_actor): + conf = {Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL} + reader_async_func = PythonFunctionDescriptor( + __name__, self.on_reader_message.__name__, self.__class__.__name__) + reader_sync_func = PythonFunctionDescriptor( + __name__, self.on_reader_message_sync.__name__, + self.__class__.__name__) + transfer.ChannelCreationParametersBuilder.\ + set_python_reader_function_descriptor( + reader_async_func, reader_sync_func) + self.writer = transfer.DataWriter([output_channel], + [pickle.loads(reader_actor)], conf) + self.output_channel_id = transfer.ChannelID(output_channel) + + def init_reader(self, input_channel, writer_actor): + conf = {Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL} + writer_async_func = PythonFunctionDescriptor( + __name__, self.on_writer_message.__name__, self.__class__.__name__) + writer_sync_func = PythonFunctionDescriptor( + __name__, self.on_writer_message_sync.__name__, + self.__class__.__name__) + transfer.ChannelCreationParametersBuilder.\ + set_python_writer_function_descriptor( + writer_async_func, writer_sync_func) + self.reader = transfer.DataReader([input_channel], + [pickle.loads(writer_actor)], conf) + + def start_write(self, msg_nums): + self.t = threading.Thread( + target=self.run_writer, args=[msg_nums], daemon=True) + self.t.start() + + 
def run_writer(self, msg_nums): + for i in range(msg_nums): + self.writer.write(self.output_channel_id, pickle.dumps(i)) + print("WriterWorker done.") + + def start_read(self, msg_nums): + self.t = threading.Thread( + target=self.run_reader, args=[msg_nums], daemon=True) + self.t.start() + + def run_reader(self, msg_nums): + count = 0 + msg = None + while count != msg_nums: + item = self.reader.read(100) + if item is None: + time.sleep(0.01) + else: + msg = pickle.loads(item.body) + count += 1 + assert msg == msg_nums - 1 + print("ReaderWorker done.") + + def is_finished(self): + return not self.t.is_alive() + + def on_reader_message(self, buffer: bytes): + """used in direct call mode""" + self.reader_client.on_reader_message(buffer) + + def on_reader_message_sync(self, buffer: bytes): + """used in direct call mode""" + if self.reader_client is None: + return b" " * 4 # special flag to indicate this actor not ready + result = self.reader_client.on_reader_message_sync(buffer) + return result.to_pybytes() + + def on_writer_message(self, buffer: bytes): + """used in direct call mode""" + self.writer_client.on_writer_message(buffer) + + def on_writer_message_sync(self, buffer: bytes): + """used in direct call mode""" + if self.writer_client is None: + return b" " * 4 # special flag to indicate this actor not ready + result = self.writer_client.on_writer_message_sync(buffer) + return result.to_pybytes() + + +def test_queue(): + ray.init() + writer = Worker._remote() + reader = Worker._remote() + channel_id_str = transfer.ChannelID.gen_random_id() + inits = [ + writer.init_writer.remote(channel_id_str, pickle.dumps(reader)), + reader.init_reader.remote(channel_id_str, pickle.dumps(writer)) + ] + ray.get(inits) + msg_nums = 1000 + print("start read/write") + reader.start_read.remote(msg_nums) + writer.start_write.remote(msg_nums) + while not ray.get(reader.is_finished.remote()): + time.sleep(0.1) + ray.shutdown() + + +if __name__ == "__main__": + test_queue() diff --git 
a/streaming/python/tests/test_failover.py b/streaming/python/tests/test_failover.py new file mode 100644 index 00000000..264a4009 --- /dev/null +++ b/streaming/python/tests/test_failover.py @@ -0,0 +1,109 @@ +import subprocess +import sys +import time +from typing import List + +import ray +from ray.streaming import StreamingContext + + +def test_word_count(): + try: + ray.init( + job_config=ray.job_config.JobConfig(code_search_path=sys.path)) + # time.sleep(10) # for gdb to attach + ctx = StreamingContext.Builder() \ + .option("streaming.context-backend.type", "local_file") \ + .option( + "streaming.context-backend.file-state.root", + "/tmp/ray/cp_files/" + ) \ + .option("streaming.checkpoint.timeout.secs", "3") \ + .build() + + print("-----------submit job-------------") + + ctx.read_text_file(__file__) \ + .set_parallelism(1) \ + .flat_map(lambda x: x.split()) \ + .map(lambda x: (x, 1)) \ + .key_by(lambda x: x[0]) \ + .reduce(lambda old_value, new_value: + (old_value[0], old_value[1] + new_value[1])) \ + .filter(lambda x: "ray" not in x) \ + .sink(lambda x: print("####result", x)) + ctx.submit("word_count") + + print("-----------checking output-------------") + retry_count = 180 / 5 # wait for 3min + while not has_sink_output(): + time.sleep(5) + retry_count -= 1 + if retry_count <= 0: + raise RuntimeError("Can not find output") + + print("-----------killing worker-------------") + time.sleep(5) + kill_all_worker() + + print("-----------checking checkpoint-------------") + cp_ok_num = checkpoint_success_num() + retry_count = 300000 / 5 # wait for 5min + while True: + cur_cp_num = checkpoint_success_num() + print("-----------checking checkpoint" + ", cur_cp_num={}, old_cp_num={}-------------".format( + cur_cp_num, cp_ok_num)) + if cur_cp_num > cp_ok_num: + print("--------------TEST OK!------------------") + break + time.sleep(5) + retry_count -= 1 + if retry_count <= 0: + raise RuntimeError( + "Checkpoint keeps failing after fail-over, test failed!") + finally: + 
ray.shutdown() + + +def run_cmd(cmd: List): + try: + out = subprocess.check_output(cmd).decode() + except subprocess.CalledProcessError as e: + out = str(e) + return out + + +def grep_log(keyword: str) -> str: + out = subprocess.check_output( + ["grep", "-r", keyword, "/tmp/ray/session_latest/logs"]) + return out.decode() + + +def has_sink_output() -> bool: + try: + grep_log("####result") + return True + except Exception: + return False + + +def checkpoint_success_num() -> int: + try: + return grep_log("Finish checkpoint").count("\n") + except Exception: + return 0 + + +def kill_all_worker(): + cmd = [ + "bash", "-c", "grep -r \'Initializing job worker, exe_vert\' " + " /tmp/ray/session_latest/logs | awk -F\'pid\' \'{print $2}\'" + "| awk -F\'=\' \'{print $2}\'" + "| xargs kill -9" + ] + print(cmd) + return subprocess.run(cmd) + + +if __name__ == "__main__": + test_word_count() diff --git a/streaming/python/tests/test_function.py b/streaming/python/tests/test_function.py new file mode 100644 index 00000000..c9ce3306 --- /dev/null +++ b/streaming/python/tests/test_function.py @@ -0,0 +1,22 @@ +from ray.streaming import function +from ray.streaming.runtime import gateway_client + + +def test_get_simple_function_class(): + simple_map_func_class = function._get_simple_function_class( + function.MapFunction) + assert simple_map_func_class is function.SimpleMapFunction + + +class MapFunc(function.MapFunction): + def map(self, value): + return str(value) + + +def test_load_function(): + # function_bytes, module_name, function_name/class_name, + # function_interface + descriptor_func_bytes = gateway_client.serialize( + [None, __name__, MapFunc.__name__, "MapFunction"]) + func = function.load_function(descriptor_func_bytes) + assert type(func) is MapFunc diff --git a/streaming/python/tests/test_hybrid_stream.py b/streaming/python/tests/test_hybrid_stream.py new file mode 100644 index 00000000..9fa3dfdd --- /dev/null +++ b/streaming/python/tests/test_hybrid_stream.py @@ -0,0 
+1,73 @@ +import os +import subprocess + +import ray +from ray.streaming import StreamingContext +from ray._private.test_utils import wait_for_condition + + +def map_func1(x): + print("HybridStreamTest map_func1", x) + return str(x) + + +def filter_func1(x): + print("HybridStreamTest filter_func1", x) + return "b" not in x + + +def sink_func1(x): + print("HybridStreamTest sink_func1 value:", x) + + +def test_hybrid_stream(): + subprocess.check_call( + ["bazel", "build", "//streaming/java:all_streaming_tests_deploy.jar"]) + current_dir = os.path.abspath(os.path.dirname(__file__)) + jar_path = os.path.join( + current_dir, + "../../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar") + jar_path = os.path.abspath(jar_path) + print("jar_path", jar_path) + assert not ray.is_initialized() + ray.init(job_config=ray.job_config.JobConfig(code_search_path=[jar_path])) + + sink_file = "/tmp/ray_streaming_test_hybrid_stream.txt" + if os.path.exists(sink_file): + os.remove(sink_file) + + def sink_func(x): + print("HybridStreamTest", x) + with open(sink_file, "a") as f: + f.write(str(x)) + f.flush() + + ctx = StreamingContext.Builder().build() + ctx.from_values("a", "b", "c") \ + .as_java_stream() \ + .map("io.ray.streaming.runtime.demo.HybridStreamTest$Mapper1") \ + .filter("io.ray.streaming.runtime.demo.HybridStreamTest$Filter1") \ + .as_python_stream() \ + .sink(sink_func) + ctx.submit("HybridStreamTest") + + def check_succeed(): + if os.path.exists(sink_file): + import time + time.sleep(3) # Wait all data be written + with open(sink_file, "r") as f: + result = f.read() + assert "a" in result + assert "b" not in result + assert "c" in result + print("Execution succeed") + return True + return False + + wait_for_condition(check_succeed, timeout=60, retry_interval_ms=1000) + print("Execution succeed") + ray.shutdown() + + +if __name__ == "__main__": + test_hybrid_stream() diff --git a/streaming/python/tests/test_operator.py b/streaming/python/tests/test_operator.py new 
file mode 100644 index 00000000..93782212 --- /dev/null +++ b/streaming/python/tests/test_operator.py @@ -0,0 +1,37 @@ +from ray.streaming import function +from ray.streaming import operator +from ray.streaming.operator import OperatorType +from ray.streaming.runtime import gateway_client + + +def test_create_operator_with_func(): + map_func = function.SimpleMapFunction(lambda x: x) + map_operator = operator.create_operator_with_func(map_func) + assert type(map_operator) is operator.MapOperator + + +class MapFunc(function.MapFunction): + def map(self, value): + return str(value) + + +class EmptyOperator(operator.StreamOperator): + def __init__(self): + super().__init__(function.EmptyFunction()) + + def operator_type(self) -> OperatorType: + return OperatorType.ONE_INPUT + + +def test_load_operator(): + # function_bytes, module_name, class_name, + descriptor_func_bytes = gateway_client.serialize( + [None, __name__, MapFunc.__name__, "MapFunction"]) + descriptor_op_bytes = gateway_client.serialize( + [descriptor_func_bytes, "", ""]) + map_operator = operator.load_operator(descriptor_op_bytes) + assert type(map_operator) is operator.MapOperator + descriptor_op_bytes = gateway_client.serialize( + [None, __name__, EmptyOperator.__name__]) + test_operator = operator.load_operator(descriptor_op_bytes) + assert isinstance(test_operator, EmptyOperator) diff --git a/streaming/python/tests/test_serialization.py b/streaming/python/tests/test_serialization.py new file mode 100644 index 00000000..67865f80 --- /dev/null +++ b/streaming/python/tests/test_serialization.py @@ -0,0 +1,13 @@ +from ray.streaming.runtime.serialization import CrossLangSerializer +from ray.streaming.message import Record, KeyRecord + + +def test_serialize(): + serializer = CrossLangSerializer() + record = Record("value") + record.stream = "stream1" + key_record = KeyRecord("key", "value") + key_record.stream = "stream2" + assert record == serializer.deserialize(serializer.serialize(record)) + assert 
key_record == serializer.\ + deserialize(serializer.serialize(key_record)) diff --git a/streaming/python/tests/test_stream.py b/streaming/python/tests/test_stream.py new file mode 100644 index 00000000..27febb82 --- /dev/null +++ b/streaming/python/tests/test_stream.py @@ -0,0 +1,51 @@ +import sys + +import ray +from ray.streaming import StreamingContext + + +def test_data_stream(): + ray.init(job_config=ray.job_config.JobConfig(code_search_path=sys.path)) + ctx = StreamingContext.Builder().build() + stream = ctx.from_values(1, 2, 3) + java_stream = stream.as_java_stream() + python_stream = java_stream.as_python_stream() + assert stream.get_id() == java_stream.get_id() + assert stream.get_id() == python_stream.get_id() + python_stream.set_parallelism(10) + assert stream.get_parallelism() == java_stream.get_parallelism() + assert stream.get_parallelism() == python_stream.get_parallelism() + ray.shutdown() + + +def test_key_data_stream(): + ray.init(job_config=ray.job_config.JobConfig(code_search_path=sys.path)) + ctx = StreamingContext.Builder().build() + key_stream = ctx.from_values( + "a", "b", "c").map(lambda x: (x, 1)).key_by(lambda x: x[0]) + java_stream = key_stream.as_java_stream() + python_stream = java_stream.as_python_stream() + assert key_stream.get_id() == java_stream.get_id() + assert key_stream.get_id() == python_stream.get_id() + python_stream.set_parallelism(10) + assert key_stream.get_parallelism() == java_stream.get_parallelism() + assert key_stream.get_parallelism() == python_stream.get_parallelism() + ray.shutdown() + + +def test_stream_config(): + ray.init(job_config=ray.job_config.JobConfig(code_search_path=sys.path)) + ctx = StreamingContext.Builder().build() + stream = ctx.from_values(1, 2, 3) + stream.with_config("k1", "v1") + print("config", stream.get_config()) + assert stream.get_config() == {"k1": "v1"} + stream.with_config(conf={"k2": "v2", "k3": "v3"}) + print("config", stream.get_config()) + assert stream.get_config() == {"k1": "v1", 
"k2": "v2", "k3": "v3"} + java_stream = stream.as_java_stream() + java_stream.with_config(conf={"k4": "v4"}) + config = java_stream.get_config() + print("config", config) + assert config == {"k1": "v1", "k2": "v2", "k3": "v3", "k4": "v4"} + ray.shutdown() diff --git a/streaming/python/tests/test_union_stream.py b/streaming/python/tests/test_union_stream.py new file mode 100644 index 00000000..bab75e62 --- /dev/null +++ b/streaming/python/tests/test_union_stream.py @@ -0,0 +1,48 @@ +import os +import sys + +import ray +from ray.streaming import StreamingContext + + +def test_union_stream(): + ray.init(job_config=ray.job_config.JobConfig(code_search_path=sys.path)) + ctx = StreamingContext.Builder() \ + .option("streaming.metrics.reporters", "") \ + .build() + sink_file = "/tmp/test_union_stream.txt" + if os.path.exists(sink_file): + os.remove(sink_file) + + def sink_func(x): + with open(sink_file, "a") as f: + print("sink_func", x) + f.write(str(x)) + + stream1 = ctx.from_values(1, 2) + stream2 = ctx.from_values(3, 4) + stream3 = ctx.from_values(5, 6) + stream1.union(stream2, stream3).sink(sink_func) + ctx.submit("test_union_stream") + import time + slept_time = 0 + while True: + if os.path.exists(sink_file): + time.sleep(3) + with open(sink_file, "r") as f: + result = f.read() + print("sink result", result) + assert set(result) == {"1", "2", "3", "4", "5", "6"} + print("Execution succeed") + break + if slept_time >= 60: + raise Exception("Execution not finished") + slept_time = slept_time + 1 + print("Wait finish...") + time.sleep(1) + + ray.shutdown() + + +if __name__ == "__main__": + test_union_stream() diff --git a/streaming/python/tests/test_word_count.py b/streaming/python/tests/test_word_count.py new file mode 100644 index 00000000..e38fc5ca --- /dev/null +++ b/streaming/python/tests/test_word_count.py @@ -0,0 +1,65 @@ +import os +import sys +import ray +from ray.streaming import StreamingContext +from ray._private.test_utils import wait_for_condition + + 
+def test_word_count(): + ray.init(job_config=ray.job_config.JobConfig(code_search_path=sys.path)) + ctx = StreamingContext.Builder() \ + .build() + ctx.read_text_file(__file__) \ + .set_parallelism(1) \ + .flat_map(lambda x: x.split()) \ + .map(lambda x: (x, 1)) \ + .key_by(lambda x: x[0]) \ + .reduce(lambda old_value, new_value: + (old_value[0], old_value[1] + new_value[1])) \ + .filter(lambda x: "ray" not in x) \ + .sink(lambda x: print("result", x)) + ctx.submit("word_count") + import time + time.sleep(3) + ray.shutdown() + + +def test_simple_word_count(): + ray.init(job_config=ray.job_config.JobConfig(code_search_path=sys.path)) + ctx = StreamingContext.Builder() \ + .build() + sink_file = "/tmp/ray_streaming_test_simple_word_count.txt" + if os.path.exists(sink_file): + os.remove(sink_file) + + def sink_func(x): + with open(sink_file, "a") as f: + line = "{}:{},".format(x[0], x[1]) + print("sink_func", line) + f.write(line) + + ctx.from_values("a", "b", "c") \ + .set_parallelism(1) \ + .flat_map(lambda x: [x, x]) \ + .map(lambda x: (x, 1)) \ + .key_by(lambda x: x[0]) \ + .reduce(lambda old_value, new_value: + (old_value[0], old_value[1] + new_value[1])) \ + .sink(sink_func) + ctx.submit("word_count") + + def check_succeed(): + if os.path.exists(sink_file): + with open(sink_file, "r") as f: + result = f.read() + return "a:2" in result and "b:2" in result and "c:2" in result + return False + + wait_for_condition(check_succeed, timeout=60, retry_interval_ms=1000) + print("Execution succeed") + ray.shutdown() + + +if __name__ == "__main__": + test_word_count() + test_simple_word_count() diff --git a/streaming/src/channel/channel.cc b/streaming/src/channel/channel.cc new file mode 100644 index 00000000..896ea216 --- /dev/null +++ b/streaming/src/channel/channel.cc @@ -0,0 +1,324 @@ +#include "channel.h" + +#include +namespace ray { +namespace streaming { + +ProducerChannel::ProducerChannel(std::shared_ptr &transfer_config, + ProducerChannelInfo &p_channel_info) + : 
transfer_config_(transfer_config), channel_info_(p_channel_info) {} + +ConsumerChannel::ConsumerChannel(std::shared_ptr &transfer_config, + ConsumerChannelInfo &c_channel_info) + : transfer_config_(transfer_config), channel_info_(c_channel_info) {} + +StreamingQueueProducer::StreamingQueueProducer(std::shared_ptr &transfer_config, + ProducerChannelInfo &p_channel_info) + : ProducerChannel(transfer_config, p_channel_info) { + STREAMING_LOG(INFO) << "Producer Init"; +} + +StreamingQueueProducer::~StreamingQueueProducer() { + STREAMING_LOG(INFO) << "Producer Destory"; +} + +StreamingStatus StreamingQueueProducer::CreateTransferChannel() { + CreateQueue(); + + STREAMING_LOG(WARNING) << "Message id in channel => " + << channel_info_.current_message_id; + + channel_info_.message_last_commit_id = 0; + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueProducer::CreateQueue() { + STREAMING_LOG(INFO) << "CreateQueue qid: " << channel_info_.channel_id + << " data_size: " << channel_info_.queue_size; + auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService(); + if (upstream_handler->UpstreamQueueExists(channel_info_.channel_id)) { + STREAMING_LOG(INFO) << "StreamingQueueProducer CreateQueue duplicate."; + return StreamingStatus::OK; + } + + upstream_handler->SetPeerActorID( + channel_info_.channel_id, channel_info_.parameter.actor_id, + *channel_info_.parameter.async_function, *channel_info_.parameter.sync_function); + queue_ = upstream_handler->CreateUpstreamQueue(channel_info_.channel_id, + channel_info_.parameter.actor_id, + channel_info_.queue_size); + STREAMING_CHECK(queue_ != nullptr); + + STREAMING_LOG(INFO) << "StreamingQueueProducer CreateQueue queue id => " + << channel_info_.channel_id << ", queue size => " + << channel_info_.queue_size; + + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueProducer::DestroyTransferChannel() { + return StreamingStatus::OK; +} + +StreamingStatus 
StreamingQueueProducer::ClearTransferCheckpoint( + uint64_t checkpoint_id, uint64_t checkpoint_offset) { + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueProducer::RefreshChannelInfo() { + channel_info_.queue_info.consumed_message_id = queue_->GetMinConsumedMsgID(); + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueProducer::NotifyChannelConsumed(uint64_t msg_id) { + queue_->SetQueueEvictionLimit(msg_id); + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueProducer::ProduceItemToChannel(uint8_t *data, + uint32_t data_size) { + StreamingMessageBundleMetaPtr meta = StreamingMessageBundleMeta::FromBytes(data); + uint64_t msg_id_end = meta->GetLastMessageId(); + uint64_t msg_id_start = + (meta->GetMessageListSize() == 0 ? msg_id_end + : msg_id_end - meta->GetMessageListSize() + 1); + + STREAMING_LOG(DEBUG) << "ProduceItemToChannel, qid=" << channel_info_.channel_id + << ", msg_id_start=" << msg_id_start + << ", msg_id_end=" << msg_id_end << ", meta=" << *meta; + + Status status = + PushQueueItem(data, data_size, current_time_ms(), msg_id_start, msg_id_end); + if (status.code() != StatusCode::OK) { + STREAMING_LOG(DEBUG) << channel_info_.channel_id << " => Queue is full" + << " meesage => " << status.message(); + + // Assume that only status OutOfMemory and OK are acceptable. + // OutOfMemory means queue is full at that moment. 
+ STREAMING_CHECK(status.code() == StatusCode::OutOfMemory) + << "status => " << status.message() + << ", perhaps data block is so large that it can't be stored in" + << ", data block size => " << data_size; + + return StreamingStatus::FullChannel; + } + return StreamingStatus::OK; +} + +Status StreamingQueueProducer::PushQueueItem(uint8_t *data, uint32_t data_size, + uint64_t timestamp, uint64_t msg_id_start, + uint64_t msg_id_end) { + STREAMING_LOG(DEBUG) << "StreamingQueueProducer::PushQueueItem:" + << " qid: " << channel_info_.channel_id + << " data_size: " << data_size; + Status status = + queue_->Push(data, data_size, timestamp, msg_id_start, msg_id_end, false); + if (status.IsOutOfMemory()) { + status = queue_->TryEvictItems(); + if (!status.ok()) { + STREAMING_LOG(INFO) << "Evict fail."; + return status; + } + + status = queue_->Push(data, data_size, timestamp, msg_id_start, msg_id_end, false); + } + + queue_->Send(); + return status; +} + +StreamingQueueConsumer::StreamingQueueConsumer(std::shared_ptr &transfer_config, + ConsumerChannelInfo &c_channel_info) + : ConsumerChannel(transfer_config, c_channel_info) { + STREAMING_LOG(INFO) << "Consumer Init"; +} + +StreamingQueueConsumer::~StreamingQueueConsumer() { + STREAMING_LOG(INFO) << "Consumer Destroy"; +} + +StreamingQueueStatus StreamingQueueConsumer::GetQueue( + const ObjectID &queue_id, uint64_t start_msg_id, + const ChannelCreationParameter &init_param) { + STREAMING_LOG(INFO) << "GetQueue qid: " << queue_id << " start_msg_id: " << start_msg_id + << " actor_id: " << init_param.actor_id; + auto downstream_handler = ray::streaming::DownstreamQueueMessageHandler::GetService(); + if (downstream_handler->DownstreamQueueExists(queue_id)) { + STREAMING_LOG(INFO) << "StreamingQueueReader:: Already got this queue."; + return StreamingQueueStatus::OK; + } + + downstream_handler->SetPeerActorID(queue_id, channel_info_.parameter.actor_id, + *init_param.async_function, + *init_param.sync_function); + 
STREAMING_LOG(INFO) << "Create ReaderQueue " << queue_id + << " pull from start_msg_id: " << start_msg_id; + queue_ = downstream_handler->CreateDownstreamQueue(queue_id, init_param.actor_id); + STREAMING_CHECK(queue_ != nullptr); + + bool is_first_pull; + return downstream_handler->PullQueue(queue_id, start_msg_id, is_first_pull); +} + +TransferCreationStatus StreamingQueueConsumer::CreateTransferChannel() { + StreamingQueueStatus status = + GetQueue(channel_info_.channel_id, channel_info_.current_message_id + 1, + channel_info_.parameter); + + if (status == StreamingQueueStatus::OK) { + return TransferCreationStatus::PullOk; + } else if (status == StreamingQueueStatus::NoValidData) { + return TransferCreationStatus::FreshStarted; + } else if (status == StreamingQueueStatus::Timeout) { + return TransferCreationStatus::Timeout; + } else if (status == StreamingQueueStatus::DataLost) { + return TransferCreationStatus::DataLost; + } + STREAMING_LOG(FATAL) << "Invalid StreamingQueueStatus, status=" << status; + return TransferCreationStatus::Invalid; +} + +StreamingStatus StreamingQueueConsumer::DestroyTransferChannel() { + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueConsumer::ClearTransferCheckpoint( + uint64_t checkpoint_id, uint64_t checkpoint_offset) { + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueConsumer::RefreshChannelInfo() { + channel_info_.queue_info.last_message_id = queue_->GetLastRecvMsgId(); + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueConsumer::ConsumeItemFromChannel(uint8_t *&data, + uint32_t &data_size, + uint32_t timeout) { + STREAMING_LOG(INFO) << "GetQueueItem qid: " << channel_info_.channel_id; + STREAMING_CHECK(queue_ != nullptr); + QueueItem item = queue_->PopPendingBlockTimeout(timeout * 1000); + if (item.SeqId() == QUEUE_INVALID_SEQ_ID) { + STREAMING_LOG(INFO) << "GetQueueItem timeout."; + data = nullptr; + data_size = 0; + return StreamingStatus::OK; + } + + data = 
item.Buffer()->Data(); + data_size = item.Buffer()->Size(); + + STREAMING_LOG(DEBUG) << "GetQueueItem qid: " << channel_info_.channel_id + << " seq_id: " << item.SeqId() << " msg_id: " << item.MaxMsgId() + << " data_size: " << data_size; + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueConsumer::NotifyChannelConsumed(uint64_t offset_id) { + STREAMING_CHECK(queue_ != nullptr); + queue_->OnConsumed(offset_id); + return StreamingStatus::OK; +} + +// For mock queue transfer +struct MockQueueItem { + uint64_t seq_id; + uint32_t data_size; + std::shared_ptr data; +}; + +class MockQueue { + public: + std::unordered_map>> + message_buffer; + std::unordered_map>> + consumed_buffer; + std::unordered_map queue_info_map; + static std::mutex mutex; + static MockQueue &GetMockQueue() { + static MockQueue mock_queue; + return mock_queue; + } +}; +std::mutex MockQueue::mutex; + +StreamingStatus MockProducer::CreateTransferChannel() { + std::unique_lock lock(MockQueue::mutex); + MockQueue &mock_queue = MockQueue::GetMockQueue(); + mock_queue.message_buffer[channel_info_.channel_id] = + std::make_shared>(10000); + mock_queue.consumed_buffer[channel_info_.channel_id] = + std::make_shared>(10000); + return StreamingStatus::OK; +} + +StreamingStatus MockProducer::DestroyTransferChannel() { + std::unique_lock lock(MockQueue::mutex); + MockQueue &mock_queue = MockQueue::GetMockQueue(); + mock_queue.message_buffer.erase(channel_info_.channel_id); + mock_queue.consumed_buffer.erase(channel_info_.channel_id); + return StreamingStatus::OK; +} + +StreamingStatus MockProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_size) { + std::unique_lock lock(MockQueue::mutex); + MockQueue &mock_queue = MockQueue::GetMockQueue(); + auto &ring_buffer = mock_queue.message_buffer[channel_info_.channel_id]; + if (ring_buffer->Full()) { + return StreamingStatus::OutOfMemory; + } + MockQueueItem item; + item.data.reset(new uint8_t[data_size]); + item.data_size = data_size; + 
std::memcpy(item.data.get(), data, data_size); + ring_buffer->Push(item); + return StreamingStatus::OK; +} + +StreamingStatus MockProducer::RefreshChannelInfo() { + MockQueue &mock_queue = MockQueue::GetMockQueue(); + channel_info_.queue_info.consumed_message_id = + mock_queue.queue_info_map[channel_info_.channel_id].consumed_message_id; + return StreamingStatus::OK; +} + +StreamingStatus MockConsumer::ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, + uint32_t timeout) { + std::unique_lock lock(MockQueue::mutex); + MockQueue &mock_queue = MockQueue::GetMockQueue(); + auto &channel_id = channel_info_.channel_id; + if (mock_queue.message_buffer.find(channel_id) == mock_queue.message_buffer.end()) { + return StreamingStatus::NoSuchItem; + } + if (mock_queue.message_buffer[channel_id]->Empty()) { + return StreamingStatus::NoSuchItem; + } + MockQueueItem item = mock_queue.message_buffer[channel_id]->Front(); + mock_queue.message_buffer[channel_id]->Pop(); + mock_queue.consumed_buffer[channel_id]->Push(item); + data = item.data.get(); + data_size = item.data_size; + return StreamingStatus::OK; +} + +StreamingStatus MockConsumer::NotifyChannelConsumed(uint64_t offset_id) { + std::unique_lock lock(MockQueue::mutex); + MockQueue &mock_queue = MockQueue::GetMockQueue(); + auto &channel_id = channel_info_.channel_id; + auto &ring_buffer = mock_queue.consumed_buffer[channel_id]; + while (!ring_buffer->Empty() && ring_buffer->Front().seq_id <= offset_id) { + ring_buffer->Pop(); + } + mock_queue.queue_info_map[channel_id].consumed_message_id = offset_id; + return StreamingStatus::OK; +} + +StreamingStatus MockConsumer::RefreshChannelInfo() { + MockQueue &mock_queue = MockQueue::GetMockQueue(); + channel_info_.queue_info.last_message_id = + mock_queue.queue_info_map[channel_info_.channel_id].last_message_id; + return StreamingStatus::OK; +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/channel/channel.h b/streaming/src/channel/channel.h 
new file mode 100644 index 00000000..582ecfaf --- /dev/null +++ b/streaming/src/channel/channel.h @@ -0,0 +1,214 @@ +#pragma once + +#include "common/status.h" +#include "config/streaming_config.h" +#include "queue/queue_handler.h" +#include "ring_buffer/ring_buffer.h" +#include "util/config.h" +#include "util/streaming_util.h" + +namespace ray { +namespace streaming { + +using namespace ray::core; + +enum class TransferCreationStatus : uint32_t { + FreshStarted = 0, + PullOk = 1, + Timeout = 2, + DataLost = 3, + Invalid = 999, +}; + +struct StreamingQueueInfo { + uint64_t first_seq_id = 0; + uint64_t last_message_id = 0; + uint64_t target_message_id = 0; + uint64_t consumed_message_id = 0; +}; + +struct ChannelCreationParameter { + ActorID actor_id; + std::shared_ptr async_function; + std::shared_ptr sync_function; +}; + +/// PrducerChannelinfo and ConsumerChannelInfo contains channel information and +/// its metrics that help us to debug or show important messages in logging. +struct ProducerChannelInfo { + ObjectID channel_id; + StreamingRingBufferPtr writer_ring_buffer; + uint64_t current_message_id; + uint64_t message_last_commit_id; + StreamingQueueInfo queue_info; + uint32_t queue_size; + int64_t message_pass_by_ts; + ChannelCreationParameter parameter; + + /// The following parameters are used for event driven to record different + /// input events. 
+ uint64_t sent_empty_cnt = 0; + uint64_t flow_control_cnt = 0; + uint64_t user_event_cnt = 0; + uint64_t rb_full_cnt = 0; + uint64_t queue_full_cnt = 0; + uint64_t in_event_queue_cnt = 0; + bool in_event_queue = false; + bool flow_control = false; +}; + +struct ConsumerChannelInfo { + ObjectID channel_id; + uint64_t current_message_id; + uint64_t barrier_id; + uint64_t partial_barrier_id; + + StreamingQueueInfo queue_info; + + uint64_t last_queue_item_delay = 0; + uint64_t last_queue_item_latency = 0; + uint64_t last_queue_target_diff = 0; + uint64_t get_queue_item_times = 0; + ChannelCreationParameter parameter; + // Total count of notify request. + uint64_t notify_cnt = 0; + uint64_t resend_notify_timer; +}; + +/// Two types of channel are presented: +/// * ProducerChannel is supporting all writing operations for upperlevel. +/// * ConsumerChannel is for all reader operations. +/// They share similar interfaces: +/// * ClearTransferCheckpoint(it's empty and unsupported now, we will add +/// implementation in next PR) +/// * NotifychannelConsumed (notify owner of channel which range data should +// be release to avoid out of memory) +/// but some differences in read/write function.(named ProduceItemTochannel and +/// ConsumeItemFrom channel) +class ProducerChannel { + public: + explicit ProducerChannel(std::shared_ptr &transfer_config, + ProducerChannelInfo &p_channel_info); + virtual ~ProducerChannel() = default; + virtual StreamingStatus CreateTransferChannel() = 0; + virtual StreamingStatus DestroyTransferChannel() = 0; + virtual StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) = 0; + virtual StreamingStatus RefreshChannelInfo() = 0; + virtual StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) = 0; + virtual StreamingStatus NotifyChannelConsumed(uint64_t channel_offset) = 0; + + protected: + std::shared_ptr transfer_config_; + ProducerChannelInfo &channel_info_; +}; + +class ConsumerChannel { + 
public: + explicit ConsumerChannel(std::shared_ptr &transfer_config, + ConsumerChannelInfo &c_channel_info); + virtual ~ConsumerChannel() = default; + virtual TransferCreationStatus CreateTransferChannel() = 0; + virtual StreamingStatus DestroyTransferChannel() = 0; + virtual StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) = 0; + virtual StreamingStatus RefreshChannelInfo() = 0; + virtual StreamingStatus ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, + uint32_t timeout) = 0; + virtual StreamingStatus NotifyChannelConsumed(uint64_t offset_id) = 0; + + protected: + std::shared_ptr transfer_config_; + ConsumerChannelInfo &channel_info_; +}; + +class StreamingQueueProducer : public ProducerChannel { + public: + explicit StreamingQueueProducer(std::shared_ptr &transfer_config, + ProducerChannelInfo &p_channel_info); + ~StreamingQueueProducer() override; + StreamingStatus CreateTransferChannel() override; + StreamingStatus DestroyTransferChannel() override; + StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) override; + StreamingStatus RefreshChannelInfo() override; + StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) override; + StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override; + + private: + StreamingStatus CreateQueue(); + Status PushQueueItem(uint8_t *data, uint32_t data_size, uint64_t timestamp, + uint64_t msg_id_start, uint64_t msg_id_end); + + private: + std::shared_ptr queue_; +}; + +class StreamingQueueConsumer : public ConsumerChannel { + public: + explicit StreamingQueueConsumer(std::shared_ptr &transfer_config, + ConsumerChannelInfo &c_channel_info); + ~StreamingQueueConsumer() override; + TransferCreationStatus CreateTransferChannel() override; + StreamingStatus DestroyTransferChannel() override; + StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) override; + StreamingStatus 
RefreshChannelInfo() override; + StreamingStatus ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, + uint32_t timeout) override; + StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override; + + private: + StreamingQueueStatus GetQueue(const ObjectID &queue_id, uint64_t start_msg_id, + const ChannelCreationParameter &init_param); + + private: + std::shared_ptr queue_; +}; + +/// MockProducer and Mockconsumer are independent implementation of channels that +/// conduct a very simple memory channel for unit tests or intergation test. +class MockProducer : public ProducerChannel { + public: + explicit MockProducer(std::shared_ptr &transfer_config, + ProducerChannelInfo &channel_info) + : ProducerChannel(transfer_config, channel_info){}; + StreamingStatus CreateTransferChannel() override; + + StreamingStatus DestroyTransferChannel() override; + + StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) override { + return StreamingStatus::OK; + } + + StreamingStatus RefreshChannelInfo() override; + + StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) override; + + StreamingStatus NotifyChannelConsumed(uint64_t channel_offset) override { + return StreamingStatus::OK; + } +}; + +class MockConsumer : public ConsumerChannel { + public: + explicit MockConsumer(std::shared_ptr &transfer_config, + ConsumerChannelInfo &c_channel_info) + : ConsumerChannel(transfer_config, c_channel_info){}; + TransferCreationStatus CreateTransferChannel() override { + return TransferCreationStatus::PullOk; + } + StreamingStatus DestroyTransferChannel() override { return StreamingStatus::OK; } + StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) override { + return StreamingStatus::OK; + } + StreamingStatus RefreshChannelInfo() override; + StreamingStatus ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, + uint32_t timeout) override; + StreamingStatus 
NotifyChannelConsumed(uint64_t offset_id) override; +}; + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/common/status.h b/streaming/src/common/status.h new file mode 100644 index 00000000..63a1cbae --- /dev/null +++ b/streaming/src/common/status.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include + +namespace ray { +namespace streaming { + +enum class StreamingStatus : uint32_t { + OK = 0, + ReconstructTimeOut = 1, + QueueIdNotFound = 3, + ResubscribeFailed = 4, + EmptyRingBuffer = 5, + FullChannel = 6, + NoSuchItem = 7, + InitQueueFailed = 8, + GetBundleTimeOut = 9, + SkipSendEmptyMessage = 10, + Interrupted = 11, + OutOfMemory = 13, + Invalid = 14, + UnknownError = 15, + TailStatus = 999, + MIN = OK, + MAX = TailStatus +}; + +static inline std::ostream &operator<<(std::ostream &os, const StreamingStatus &status) { + os << static_cast::type>(status); + return os; +} + +#define RETURN_IF_NOT_OK(STATUS_EXP) \ + { \ + StreamingStatus state = STATUS_EXP; \ + if (StreamingStatus::OK != state) { \ + return state; \ + } \ + } + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/config/streaming_config.cc b/streaming/src/config/streaming_config.cc new file mode 100644 index 00000000..7dc94c86 --- /dev/null +++ b/streaming/src/config/streaming_config.cc @@ -0,0 +1,55 @@ +#include "config/streaming_config.h" + +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +uint64_t StreamingConfig::TIME_WAIT_UINT = 1; +uint32_t StreamingConfig::DEFAULT_RING_BUFFER_CAPACITY = 500; +uint32_t StreamingConfig::DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL = 20; +// Time to force clean if barrier in queue, default 0ms +const uint32_t StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE = 2048; +const uint32_t StreamingConfig::RESEND_NOTIFY_MAX_INTERVAL = 1000; // ms + +#define RESET_IF_INT_CONF(KEY, VALUE) \ + if (0 != VALUE) { \ + Set##KEY(VALUE); \ + } +#define RESET_IF_STR_CONF(KEY, VALUE) \ + if (!VALUE.empty()) { \ + 
Set##KEY(VALUE); \ + } +#define RESET_IF_NOT_DEFAULT_CONF(KEY, VALUE, DEFAULT) \ + if (DEFAULT != VALUE) { \ + Set##KEY(VALUE); \ + } + +void StreamingConfig::FromProto(const uint8_t *data, uint32_t size) { + proto::StreamingConfig config; + STREAMING_CHECK(config.ParseFromArray(data, size)) << "Parse streaming conf failed"; + RESET_IF_STR_CONF(JobName, config.job_name()) + RESET_IF_STR_CONF(WorkerName, config.worker_name()) + RESET_IF_STR_CONF(OpName, config.op_name()) + RESET_IF_NOT_DEFAULT_CONF(NodeType, config.role(), proto::NodeType::UNKNOWN) + RESET_IF_INT_CONF(RingBufferCapacity, config.ring_buffer_capacity()) + RESET_IF_INT_CONF(EmptyMessageTimeInterval, config.empty_message_interval()) + RESET_IF_NOT_DEFAULT_CONF(FlowControlType, config.flow_control_type(), + proto::FlowControlType::UNKNOWN_FLOW_CONTROL_TYPE) + RESET_IF_INT_CONF(WriterConsumedStep, config.writer_consumed_step()) + RESET_IF_INT_CONF(ReaderConsumedStep, config.reader_consumed_step()) + RESET_IF_INT_CONF(EventDrivenFlowControlInterval, + config.event_driven_flow_control_interval()) + STREAMING_CHECK(writer_consumed_step_ >= reader_consumed_step_) + << "Writer consuemd step " << writer_consumed_step_ + << "can not be smaller then reader consumed step " << reader_consumed_step_; +} + +uint32_t StreamingConfig::GetRingBufferCapacity() const { return ring_buffer_capacity_; } + +void StreamingConfig::SetRingBufferCapacity(uint32_t ring_buffer_capacity) { + StreamingConfig::ring_buffer_capacity_ = + std::min(ring_buffer_capacity, StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE); +} +} // namespace streaming +} // namespace ray diff --git a/streaming/src/config/streaming_config.h b/streaming/src/config/streaming_config.h new file mode 100644 index 00000000..1485e8b8 --- /dev/null +++ b/streaming/src/config/streaming_config.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include + +#include "protobuf/streaming.pb.h" +#include "ray/common/id.h" + +namespace ray { +namespace streaming { + +using 
ReliabilityLevel = proto::ReliabilityLevel; +using StreamingRole = proto::NodeType; + +#define DECL_GET_SET_PROPERTY(TYPE, NAME, VALUE) \ + TYPE Get##NAME() const { return VALUE; } \ + void Set##NAME(TYPE value) { VALUE = value; } + +using TagsMap = std::unordered_map; +class StreamingMetricsConfig { + public: + DECL_GET_SET_PROPERTY(const std::string &, MetricsServiceName, metrics_service_name_); + DECL_GET_SET_PROPERTY(uint32_t, MetricsReportInterval, metrics_report_interval_); + DECL_GET_SET_PROPERTY(const TagsMap, MetricsGlobalTags, global_tags); + + private: + std::string metrics_service_name_ = "streaming"; + uint32_t metrics_report_interval_ = 10; + std::unordered_map global_tags; +}; + +class StreamingConfig { + public: + static uint64_t TIME_WAIT_UINT; + static uint32_t DEFAULT_RING_BUFFER_CAPACITY; + static uint32_t DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL; + static const uint32_t MESSAGE_BUNDLE_MAX_SIZE; + static const uint32_t RESEND_NOTIFY_MAX_INTERVAL; + + private: + uint32_t ring_buffer_capacity_ = DEFAULT_RING_BUFFER_CAPACITY; + + uint32_t empty_message_time_interval_ = DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL; + + streaming::proto::NodeType node_type_ = streaming::proto::NodeType::TRANSFORM; + + std::string job_name_ = "DEFAULT_JOB_NAME"; + + std::string op_name_ = "DEFAULT_OP_NAME"; + + std::string worker_name_ = "DEFAULT_WORKER_NAME"; + + // Default flow control type is unconsumed sequence flow control. More detail + // introducation and implemention in ray/streaming/src/flow_control.h. + streaming::proto::FlowControlType flow_control_type_ = + streaming::proto::FlowControlType::UnconsumedSeqFlowControl; + + // Default writer and reader consumed step. 
+ uint32_t writer_consumed_step_ = 1000; + uint32_t reader_consumed_step_ = 100; + + uint32_t event_driven_flow_control_interval_ = 1; + + ReliabilityLevel streaming_strategy_ = ReliabilityLevel::EXACTLY_ONCE; + StreamingRole streaming_role = StreamingRole::TRANSFORM; + bool metrics_enable = true; + + public: + void FromProto(const uint8_t *, uint32_t size); + + inline bool IsAtLeastOnce() const { + return ReliabilityLevel::AT_LEAST_ONCE == streaming_strategy_; + } + inline bool IsExactlyOnce() const { + return ReliabilityLevel::EXACTLY_ONCE == streaming_strategy_; + } + + DECL_GET_SET_PROPERTY(const std::string &, WorkerName, worker_name_) + DECL_GET_SET_PROPERTY(const std::string &, OpName, op_name_) + DECL_GET_SET_PROPERTY(uint32_t, EmptyMessageTimeInterval, empty_message_time_interval_) + DECL_GET_SET_PROPERTY(streaming::proto::NodeType, NodeType, node_type_) + DECL_GET_SET_PROPERTY(const std::string &, JobName, job_name_) + DECL_GET_SET_PROPERTY(uint32_t, WriterConsumedStep, writer_consumed_step_) + DECL_GET_SET_PROPERTY(uint32_t, ReaderConsumedStep, reader_consumed_step_) + DECL_GET_SET_PROPERTY(streaming::proto::FlowControlType, FlowControlType, + flow_control_type_) + DECL_GET_SET_PROPERTY(uint32_t, EventDrivenFlowControlInterval, + event_driven_flow_control_interval_) + DECL_GET_SET_PROPERTY(StreamingRole, StreamingRole, streaming_role) + DECL_GET_SET_PROPERTY(ReliabilityLevel, ReliabilityLevel, streaming_strategy_) + DECL_GET_SET_PROPERTY(bool, MetricsEnable, metrics_enable) + + uint32_t GetRingBufferCapacity() const; + /// Note(lingxuan.zlx), RingBufferCapacity's valid range is from 1 to + /// MESSAGE_BUNDLE_MAX_SIZE, so we don't use DECL_GET_SET_PROPERTY for it. 
+ void SetRingBufferCapacity(uint32_t ring_buffer_capacity); +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/data_reader.cc b/streaming/src/data_reader.cc new file mode 100644 index 00000000..2d481fab --- /dev/null +++ b/streaming/src/data_reader.cc @@ -0,0 +1,502 @@ +#include "data_reader.h" + +#include +#include +#include +#include +#include +#include + +#include "message/message_bundle.h" +#include "ray/util/logging.h" +#include "ray/util/util.h" + +namespace ray { +namespace streaming { + +const uint32_t DataReader::kReadItemTimeout = 1000; + +void DataReader::Init(const std::vector &input_ids, + const std::vector &init_params, + const std::vector &streaming_msg_ids, + std::vector &creation_status, + int64_t timer_interval) { + Init(input_ids, init_params, timer_interval); + for (size_t i = 0; i < input_ids.size(); ++i) { + auto &q_id = input_ids[i]; + last_message_id_[q_id] = streaming_msg_ids[i]; + channel_info_map_[q_id].current_message_id = streaming_msg_ids[i]; + } + InitChannel(creation_status); +} + +void DataReader::Init(const std::vector &input_ids, + const std::vector &init_params, + int64_t timer_interval) { + STREAMING_LOG(INFO) << input_ids.size() << " queue to init."; + + transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, input_ids); + + last_fetched_queue_item_ = nullptr; + timer_interval_ = timer_interval; + last_message_ts_ = 0; + input_queue_ids_ = input_ids; + last_message_latency_ = 0; + last_bundle_unit_ = 0; + + for (size_t i = 0; i < input_ids.size(); ++i) { + ObjectID q_id = input_ids[i]; + STREAMING_LOG(INFO) << "[Reader] Init queue id: " << q_id; + auto &channel_info = channel_info_map_[q_id]; + channel_info.channel_id = q_id; + channel_info.parameter = init_params[i]; + channel_info.last_queue_item_delay = 0; + channel_info.last_queue_item_latency = 0; + channel_info.last_queue_target_diff = 0; + channel_info.get_queue_item_times = 0; + channel_info.resend_notify_timer = 0; + } + + reliability_helper_ = 
ReliabilityHelperFactory::CreateReliabilityHelper( + runtime_context_->GetConfig(), barrier_helper_, nullptr, this); + + /// Make the input id location stable. + sort(input_queue_ids_.begin(), input_queue_ids_.end(), + [](const ObjectID &a, const ObjectID &b) { return a.Hash() < b.Hash(); }); + std::copy(input_ids.begin(), input_ids.end(), std::back_inserter(unready_queue_ids_)); +} + +StreamingStatus DataReader::InitChannel( + std::vector &creation_status) { + STREAMING_LOG(INFO) << "[Reader] Getting queues. total queue num " + << input_queue_ids_.size() + << ", unready queue num=" << unready_queue_ids_.size(); + + for (const auto &input_channel : unready_queue_ids_) { + auto &channel_info = channel_info_map_[input_channel]; + std::shared_ptr channel; + if (runtime_context_->IsMockTest()) { + channel = std::make_shared(transfer_config_, channel_info); + } else { + channel = std::make_shared(transfer_config_, channel_info); + } + + channel_map_.emplace(input_channel, channel); + TransferCreationStatus status = channel->CreateTransferChannel(); + creation_status.push_back(status); + if (TransferCreationStatus::PullOk != status) { + STREAMING_LOG(ERROR) << "Initialize queue failed, id=" << input_channel + << ", status=" << static_cast(status); + } + } + runtime_context_->SetRuntimeStatus(RuntimeStatus::Running); + STREAMING_LOG(INFO) << "[Reader] Reader construction done!"; + return StreamingStatus::OK; +} + +StreamingStatus DataReader::InitChannelMerger(uint32_t timeout_ms) { + STREAMING_LOG(INFO) << "[Reader] Initializing queue merger."; + // Init reader merger by given comparator when it's first created. + StreamingReaderMsgPtrComparator comparator( + runtime_context_->GetConfig().GetReliabilityLevel()); + if (!reader_merger_) { + reader_merger_.reset( + new PriorityQueue, StreamingReaderMsgPtrComparator>( + comparator)); + } + + // An old item in merger vector must be evicted before new queue item has been + // pushed. 
+ if (!unready_queue_ids_.empty() && last_fetched_queue_item_) { + STREAMING_LOG(INFO) << "pop old item from=" << last_fetched_queue_item_->from; + RETURN_IF_NOT_OK(StashNextMessageAndPop(last_fetched_queue_item_, timeout_ms)) + last_fetched_queue_item_.reset(); + } + // Create initial heap for priority queue. + std::vector unready_queue_ids_stashed; + for (auto &input_queue : unready_queue_ids_) { + std::shared_ptr msg = std::make_shared(); + auto status = GetMessageFromChannel(channel_info_map_[input_queue], msg, timeout_ms, + timeout_ms); + if (StreamingStatus::OK != status) { + STREAMING_LOG(INFO) + << "[Reader] initializing merger, get message from channel timeout, " + << input_queue << ", status => " << static_cast(status); + unready_queue_ids_stashed.push_back(input_queue); + continue; + } + channel_info_map_[msg->from].current_message_id = msg->meta->GetLastMessageId(); + reader_merger_->push(msg); + } + if (unready_queue_ids_stashed.empty()) { + STREAMING_LOG(INFO) << "[Reader] Initializing merger done."; + return StreamingStatus::OK; + } else { + STREAMING_LOG(INFO) << "[Reader] Initializing merger unfinished."; + unready_queue_ids_ = unready_queue_ids_stashed; + return StreamingStatus::GetBundleTimeOut; + } +} + +StreamingStatus DataReader::GetMessageFromChannel(ConsumerChannelInfo &channel_info, + std::shared_ptr &message, + uint32_t timeout_ms, + uint32_t wait_time_ms) { + auto &qid = channel_info.channel_id; + message->from = qid; + last_read_q_id_ = qid; + + bool is_valid_bundle = false; + int64_t start_time = current_sys_time_ms(); + STREAMING_LOG(DEBUG) << "GetMessageFromChannel, timeout_ms=" << timeout_ms + << ", wait_time_ms=" << wait_time_ms; + while (runtime_context_->GetRuntimeStatus() == RuntimeStatus::Running && + !is_valid_bundle && current_sys_time_ms() - start_time < timeout_ms) { + STREAMING_LOG(DEBUG) << "[Reader] send get request queue seq id=" << qid; + /// In AT_LEAST_ONCE, wait_time_ms is set to 0, means `ConsumeItemFromChannel` + 
/// will return immediately if no items in queue. At the same time, `timeout_ms` is + /// ignored. + channel_map_[channel_info.channel_id]->ConsumeItemFromChannel( + message->data, message->data_size, wait_time_ms); + + STREAMING_LOG(DEBUG) << "ConsumeItemFromChannel done, bytes=" + << Util::Byte2hex(message->data, message->data_size); + + channel_info.get_queue_item_times++; + if (!message->data) { + RETURN_IF_NOT_OK(reliability_helper_->HandleNoValidItem(channel_info)); + } else { + uint64_t current_time = current_sys_time_ms(); + channel_info.resend_notify_timer = current_time; + // Note(lingxuan.zlx): To find which channel get an invalid data and + // print channel id for debugging. + STREAMING_CHECK(StreamingMessageBundleMeta::CheckBundleMagicNum(message->data)) + << "Magic number invalid, from channel " << channel_info.channel_id; + message->meta = StreamingMessageBundleMeta::FromBytes(message->data); + + is_valid_bundle = true; + if (!runtime_context_->GetConfig().IsAtLeastOnce()) { + // filter message when msg_id doesn't match. 
+ // reader will filter message only when using streaming queue and + // non-at-least-once mode + BundleCheckStatus status = CheckBundle(message); + STREAMING_LOG(DEBUG) << "CheckBundle, result=" << status + << ", last_msg_id=" << last_message_id_[message->from]; + if (status == BundleCheckStatus::BundleToBeSplit) { + SplitBundle(message, last_message_id_[qid]); + } + if (status == BundleCheckStatus::BundleToBeThrown && message->meta->IsBarrier()) { + STREAMING_LOG(WARNING) + << "Throw barrier, msg_id=" << message->meta->GetLastMessageId(); + } + is_valid_bundle = status != BundleCheckStatus::BundleToBeThrown; + } + } + } + if (RuntimeStatus::Interrupted == runtime_context_->GetRuntimeStatus()) { + return StreamingStatus::Interrupted; + } + + if (!is_valid_bundle) { + STREAMING_LOG(DEBUG) << "GetMessageFromChannel timeout, qid=" + << channel_info.channel_id; + return StreamingStatus::GetBundleTimeOut; + } + + STREAMING_LOG(DEBUG) << "[Reader] received message id=" + << message->meta->GetLastMessageId() << ", queue id=" << qid; + last_message_id_[message->from] = message->meta->GetLastMessageId(); + return StreamingStatus::OK; +} + +BundleCheckStatus DataReader::CheckBundle(const std::shared_ptr &message) { + uint64_t end_msg_id = message->meta->GetLastMessageId(); + uint64_t start_msg_id = message->meta->IsEmptyMsg() + ? end_msg_id + : end_msg_id - message->meta->GetMessageListSize() + 1; + uint64_t last_msg_id = last_message_id_[message->from]; + + // Writer will keep sending bundles when downstream reader failover. After reader + // recovered, it will receive these bundles whoes msg_id is larger than expected. + if (start_msg_id > last_msg_id + 1) { + return BundleCheckStatus::BundleToBeThrown; + } + if (end_msg_id < last_msg_id + 1) { + // Empty message and barrier's msg_id equals to last message, so we shouldn't throw + // them. + return end_msg_id == last_msg_id && !message->meta->IsBundle() + ? 
BundleCheckStatus::OkBundle + : BundleCheckStatus::BundleToBeThrown; + } + // Normal bundles. + if (start_msg_id == last_msg_id + 1) { + return BundleCheckStatus::OkBundle; + } + return BundleCheckStatus::BundleToBeSplit; +} + +/// Rewrite \p message in place so it only contains the messages whose id is +/// strictly greater than \p last_msg_id: deserialize the raw bundle, drop the +/// already-consumed prefix, re-serialize into the (realloc'ed) message buffer +/// and refresh message->meta from the new bytes. +/// NOTE(review): msg_list.front()/back() assume at least one message survives +/// the filter — this holds when called for BundleToBeSplit (CheckBundle only +/// returns it when end_msg_id > last_msg_id), but would be UB otherwise; +/// confirm no other call site exists. +void DataReader::SplitBundle(std::shared_ptr &message, uint64_t last_msg_id) { + std::list msg_list; + StreamingMessageBundle::GetMessageListFromRawData( + message->data + kMessageBundleHeaderSize, + message->data_size - kMessageBundleHeaderSize, message->meta->GetMessageListSize(), + msg_list); + uint32_t bundle_size = 0; + // Keep only messages newer than last_msg_id, accumulating their byte size. + for (auto it = msg_list.begin(); it != msg_list.end();) { + if ((*it)->GetMessageId() > last_msg_id) { + bundle_size += (*it)->ClassBytesSize(); + it++; + } else { + it = msg_list.erase(it); + } + } + STREAMING_LOG(DEBUG) << "Split message, from_queue_id=" << message->from + << ", start_msg_id=" << msg_list.front()->GetMessageId() + << ", end_msg_id=" << msg_list.back()->GetMessageId(); + // recreate bundle + auto cut_msg_bundle = std::make_shared( + msg_list, message->meta->GetMessageBundleTs(), msg_list.back()->GetMessageId(), + StreamingMessageBundleType::Bundle, bundle_size); + message->Realloc(cut_msg_bundle->ClassBytesSize()); + cut_msg_bundle->ToBytes(message->data); + message->meta = StreamingMessageBundleMeta::FromBytes(message->data); +} + +StreamingStatus DataReader::StashNextMessageAndPop(std::shared_ptr &message, + uint32_t timeout_ms) { + STREAMING_LOG(DEBUG) << "StashNextMessageAndPop, timeout_ms=" << timeout_ms; + + // Get the first message. + message = reader_merger_->top(); + STREAMING_LOG(DEBUG) << "Messages to be popped=" << *message + << ", merger size=" << reader_merger_->size() + << ", bytes=" << Util::Byte2hex(message->data, message->data_size); + + // Then stash next message from its from queue. 
+ std::shared_ptr new_msg = std::make_shared(); + auto &channel_info = channel_info_map_[message->from]; + RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info, new_msg, timeout_ms, timeout_ms)) + new_msg->last_barrier_id = channel_info.barrier_id; + reader_merger_->push(new_msg); + STREAMING_LOG(DEBUG) << "New message pushed=" << *new_msg + << ", merger size=" << reader_merger_->size() + << ", bytes=" << Util::Byte2hex(new_msg->data, new_msg->data_size); + // Barrier's message ID is equal to last message's ID. + // We will mark last message's ID as consumed in GetBundle. + // So barrier might be erased by streaming queue. We make a hack way here to + // copy barrier's data from streaming queue. TODO: There should be a more elegant way to + // do this. + if (new_msg->meta->IsBarrier()) { + uint8_t *origin_data = new_msg->data; + new_msg->Realloc(new_msg->data_size); + memcpy(new_msg->data, origin_data, new_msg->data_size); + } + + // Pop message. + reader_merger_->pop(); + STREAMING_LOG(DEBUG) << "Message popped, msg=" << *message + << ", bytes=" << Util::Byte2hex(message->data, message->data_size); + + // Record some metrics. 
+ channel_info.last_queue_item_delay = + new_msg->meta->GetMessageBundleTs() - message->meta->GetMessageBundleTs(); + // BUG(review): `current_time_ms() - current_time_ms()` is always ~0, so + // last_queue_item_latency never records a real latency. It should subtract a + // timestamp captured before GetMessageFromChannel() was invoked above. + channel_info.last_queue_item_latency = current_time_ms() - current_time_ms(); + return StreamingStatus::OK; +} + +/// Pop the highest-priority bundle from the merger into \p message (stashing a +/// replacement item from the same channel first), then decide whether it is +/// returnable to the caller: \p is_valid_break becomes true for a data bundle, +/// for a barrier that completed alignment, or for an empty message once +/// timer_interval_ has elapsed since the last returned message. Also advances +/// the channel's current_message_id and remembers the item for later +/// NotifyConsumed. +StreamingStatus DataReader::GetMergedMessageBundle(std::shared_ptr &message, + bool &is_valid_break, + uint32_t timeout_ms) { + RETURN_IF_NOT_OK(StashNextMessageAndPop(message, timeout_ms)) + + auto &offset_info = channel_info_map_[message->from]; + uint64_t cur_queue_previous_msg_id = offset_info.current_message_id; + STREAMING_LOG(DEBUG) << "[Reader] [Bundle]" << *message + << ", cur_queue_previous_msg_id=" << cur_queue_previous_msg_id; + int64_t cur_time = current_time_ms(); + if (message->meta->IsBundle()) { + last_message_ts_ = cur_time; + is_valid_break = true; + } else if (message->meta->IsBarrier() && BarrierAlign(message)) { + // Barrier only breaks the read loop once all input channels aligned. + last_message_ts_ = cur_time; + is_valid_break = true; + } else if (timer_interval_ != -1 && cur_time - last_message_ts_ >= timer_interval_ && + message->meta->IsEmptyMsg()) { + // Sent empty message when reaching timer_interval + last_message_ts_ = cur_time; + is_valid_break = true; + } + + offset_info.current_message_id = message->meta->GetLastMessageId(); + last_bundle_ts_ = message->meta->GetMessageBundleTs(); + + STREAMING_LOG(DEBUG) << "[Reader] [Bundle] Get merged message bundle=" << *message + << ", is_valid_break=" << is_valid_break; + last_fetched_queue_item_ = message; + return StreamingStatus::OK; +} + +/// Count one arrival of this barrier id; returns true only when every input +/// channel has delivered it (global alignment reached), erasing the counter. +bool DataReader::BarrierAlign(std::shared_ptr &message) { + // Arrange barrier action when barrier is arriving. + StreamingBarrierHeader barrier_header; + StreamingMessage::GetBarrierIdFromRawData(message->data + kMessageHeaderSize, + &barrier_header); + uint64_t barrier_id = barrier_header.barrier_id; + auto *barrier_align_cnt = &global_barrier_cnt_; + auto &channel_info = channel_info_map_[message->from]; + // Target count is input vector size (global barrier). 
+ uint32_t target_count = 0; + + channel_info.barrier_id = barrier_header.barrier_id; + target_count = input_queue_ids_.size(); + (*barrier_align_cnt)[barrier_id]++; + // The next message checkpoint is changed if this's barrier message. + STREAMING_LOG(INFO) << "[Reader] [Barrier] get barrier, barrier_id=" << barrier_id + << ", barrier_cnt=" << (*barrier_align_cnt)[barrier_id] + << ", global barrier id=" << barrier_header.barrier_id + << ", from q_id=" << message->from << ", barrier type=" + << static_cast(barrier_header.barrier_type) + << ", target count=" << target_count; + // Notify invoker the last barrier, so that checkpoint or something related can be + // taken right now. + if ((*barrier_align_cnt)[barrier_id] == target_count) { + // map can't be used in multithread (crash in report timer) + barrier_align_cnt->erase(barrier_id); + STREAMING_LOG(INFO) + << "[Reader] [Barrier] last barrier received, return barrier. barrier_id = " + << barrier_id << ", from q_id=" << message->from; + return true; + } + return false; +} + +StreamingStatus DataReader::GetBundle(const uint32_t timeout_ms, + std::shared_ptr &message) { + STREAMING_LOG(DEBUG) << "GetBundle, timeout_ms=" << timeout_ms; + // Notify upstream that last fetched item has been consumed. + if (last_fetched_queue_item_) { + NotifyConsumed(last_fetched_queue_item_); + } + + /// DataBundle will be returned to the upper layer in the following cases: + /// a batch of data is returned when the real data is read, or an empty message + /// is returned to the upper layer when the given timeout period is reached to + /// avoid blocking for too long. 
+ auto start_time = current_time_ms(); + bool is_valid_break = false; + uint32_t empty_bundle_cnt = 0; + while (!is_valid_break) { + if (RuntimeStatus::Interrupted == runtime_context_->GetRuntimeStatus()) { + return StreamingStatus::Interrupted; + } + auto cur_time = current_time_ms(); + auto dur = cur_time - start_time; + if (dur > timeout_ms) { + return StreamingStatus::GetBundleTimeOut; + } + if (!unready_queue_ids_.empty()) { + std::vector creation_status; + StreamingStatus status = InitChannel(creation_status); + switch (status) { + case StreamingStatus::InitQueueFailed: + break; + default: + STREAMING_LOG(INFO) << "Init reader queue in GetBundle"; + } + if (StreamingStatus::OK != status) { + return status; + } + RETURN_IF_NOT_OK(InitChannelMerger(timeout_ms)) + unready_queue_ids_.clear(); + auto &merge_vec = reader_merger_->getRawVector(); + for (auto &bundle : merge_vec) { + STREAMING_LOG(INFO) << "merger vector item=" << bundle->from; + } + } + + RETURN_IF_NOT_OK(GetMergedMessageBundle(message, is_valid_break, timeout_ms)); + if (!is_valid_break) { + empty_bundle_cnt++; + NotifyConsumed(message); + } + } + last_message_latency_ += current_time_ms() - start_time; + if (message->meta->GetMessageListSize() > 0) { + last_bundle_unit_ = message->data_size * 1.0 / message->meta->GetMessageListSize(); + } + return StreamingStatus::OK; +} + +void DataReader::GetOffsetInfo( + std::unordered_map *&offset_map) { + offset_map = &channel_info_map_; + for (auto &offset_info : channel_info_map_) { + STREAMING_LOG(INFO) << "[Reader] [GetOffsetInfo], q id " << offset_info.first + << ", message id=" << offset_info.second.current_message_id; + } +} + +void DataReader::NotifyConsumedItem(ConsumerChannelInfo &channel_info, uint64_t offset) { + STREAMING_LOG(DEBUG) << "NotifyConsumedItem, offset=" << offset + << ", channel_id=" << channel_info.channel_id; + channel_map_[channel_info.channel_id]->NotifyChannelConsumed(offset); +} + +DataReader::DataReader(std::shared_ptr 
&runtime_context) + : transfer_config_(new Config()), runtime_context_(runtime_context) {} + +DataReader::~DataReader() { STREAMING_LOG(INFO) << "Streaming reader deconstruct."; } + +void DataReader::Stop() { + runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted); +} + +void DataReader::NotifyConsumed(std::shared_ptr &message) { + auto &channel_info = channel_info_map_[message->from]; + auto &queue_info = channel_info.queue_info; + channel_info.notify_cnt++; + if (queue_info.target_message_id <= message->meta->GetLastMessageId()) { + NotifyConsumedItem(channel_info, message->meta->GetLastMessageId()); + + channel_map_[channel_info.channel_id]->RefreshChannelInfo(); + if (queue_info.last_message_id != QUEUE_INVALID_SEQ_ID) { + uint64_t original_target_message_id = queue_info.target_message_id; + queue_info.target_message_id = + std::min(queue_info.last_message_id, + message->meta->GetLastMessageId() + + runtime_context_->GetConfig().GetReaderConsumedStep()); + channel_info.last_queue_target_diff = + queue_info.target_message_id - original_target_message_id; + } else { + STREAMING_LOG(WARNING) << "[Reader] [QueueInfo] channel id " << message->from + << ", last message id " << queue_info.last_message_id; + } + STREAMING_LOG(DEBUG) << "[Reader] [Consumed] Trigger notify consumed" + << ", channel id=" << message->from + << ", last message id=" << queue_info.last_message_id + << ", target message id=" << queue_info.target_message_id + << ", consumed message id=" << message->meta->GetLastMessageId() + << ", bundle type=" + << static_cast(message->meta->GetBundleType()) + << ", last message bundle ts=" + << message->meta->GetMessageBundleTs(); + } +} + +bool StreamingReaderMsgPtrComparator::operator()(const std::shared_ptr &a, + const std::shared_ptr &b) { + if (comp_strategy == ReliabilityLevel::EXACTLY_ONCE) { + if (a->last_barrier_id != b->last_barrier_id) + return a->last_barrier_id > b->last_barrier_id; + } + STREAMING_CHECK(a->meta); + // We proposed fixed 
id sequnce for stability of message in sorting. + if (a->meta->GetMessageBundleTs() == b->meta->GetMessageBundleTs()) { + return a->from.Hash() > b->from.Hash(); + } + return a->meta->GetMessageBundleTs() > b->meta->GetMessageBundleTs(); +} + +} // namespace streaming + +} // namespace ray diff --git a/streaming/src/data_reader.h b/streaming/src/data_reader.h new file mode 100644 index 00000000..615047ec --- /dev/null +++ b/streaming/src/data_reader.h @@ -0,0 +1,161 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "channel/channel.h" +#include "message/message_bundle.h" +#include "message/priority_queue.h" +#include "reliability/barrier_helper.h" +#include "reliability_helper.h" +#include "runtime_context.h" + +namespace ray { +namespace streaming { + +class ReliabilityHelper; +class AtLeastOnceHelper; + +enum class BundleCheckStatus : uint32_t { + OkBundle = 0, + BundleToBeThrown = 1, + BundleToBeSplit = 2 +}; + +static inline std::ostream &operator<<(std::ostream &os, + const BundleCheckStatus &status) { + os << static_cast::type>(status); + return os; +} + +/// This is implementation of merger policy in StreamingReaderMsgPtrComparator. +struct StreamingReaderMsgPtrComparator { + explicit StreamingReaderMsgPtrComparator(ReliabilityLevel strategy) + : comp_strategy(strategy){}; + StreamingReaderMsgPtrComparator(){}; + ReliabilityLevel comp_strategy = ReliabilityLevel::EXACTLY_ONCE; + + bool operator()(const std::shared_ptr &a, + const std::shared_ptr &b); +}; + +/// DataReader will fetch data bundles from channels of upstream workers, once +/// invoked by user thread. Firstly put them into a priority queue ordered by bundle +/// comparator that's related meta-data, then pop out the top bunlde to user +/// thread every time, so that the order of the message can be guranteed, which +/// will also facilitate our future implementation of fault tolerance. 
Finally +/// user thread can extract messages from the bundle and process one by one. +class DataReader { + private: + std::vector input_queue_ids_; + + std::vector unready_queue_ids_; + + std::unique_ptr< + PriorityQueue, StreamingReaderMsgPtrComparator>> + reader_merger_; + + std::shared_ptr last_fetched_queue_item_; + + std::unordered_map global_barrier_cnt_; + + int64_t timer_interval_; + int64_t last_bundle_ts_; + int64_t last_message_ts_; + int64_t last_message_latency_; + int64_t last_bundle_unit_; + + ObjectID last_read_q_id_; + + static const uint32_t kReadItemTimeout; + StreamingBarrierHelper barrier_helper_; + std::shared_ptr reliability_helper_; + std::unordered_map last_message_id_; + + friend class ReliabilityHelper; + friend class AtLeastOnceHelper; + + protected: + std::unordered_map channel_info_map_; + std::unordered_map> channel_map_; + std::shared_ptr transfer_config_; + std::shared_ptr runtime_context_; + + public: + explicit DataReader(std::shared_ptr &runtime_context); + virtual ~DataReader(); + + /// During initialization, only the channel parameters and necessary member properties + /// are assigned. All channels will be connected in the first reading operation. + /// \param input_ids + /// \param init_params + /// \param msg_ids + /// \param[out] creation_status + /// \param timer_interval + void Init(const std::vector &input_ids, + const std::vector &init_params, + const std::vector &msg_ids, + std::vector &creation_status, int64_t timer_interval); + + /// Create reader use msg_id=0, this method is public only for test, and users + /// usuallly don't need it. + /// \param input_ids + /// \param init_params + /// \param timer_interval + void Init(const std::vector &input_ids, + const std::vector &init_params, + int64_t timer_interval); + + /// Get latest message from input queues. 
+ /// \param timeout_ms + /// \param message, return the latest message + StreamingStatus GetBundle(uint32_t timeout_ms, std::shared_ptr &message); + + /// Get offset information about channels for checkpoint. + /// \param offset_map (return value) + void GetOffsetInfo(std::unordered_map *&offset_map); + + void Stop(); + + /// Notify input queues to clear data whose seq id is equal or less than offset. + /// It's used when checkpoint is done. + /// \param channel_info consumer's channel info + /// \param offset consumed channel offset + void NotifyConsumedItem(ConsumerChannelInfo &channel_info, uint64_t offset); + + //// Notify message related channel to clear data. + void NotifyConsumed(std::shared_ptr &message); + + private: + /// Create channels and connect to all upstream. + StreamingStatus InitChannel(std::vector &creation_status); + + /// One item from every channel will be popped out, then collecting + /// them to a merged queue. High prioprity items will be fetched one by one. + /// When item pop from one channel where must produce new item for placeholder + /// in merged queue. + StreamingStatus InitChannelMerger(uint32_t timeout_ms); + + StreamingStatus StashNextMessageAndPop(std::shared_ptr &message, + uint32_t timeout_ms); + + StreamingStatus GetMessageFromChannel(ConsumerChannelInfo &channel_info, + std::shared_ptr &message, + uint32_t timeout_ms, uint32_t wait_time_ms); + + /// Get top item from prioprity queue. 
+ StreamingStatus GetMergedMessageBundle(std::shared_ptr &message, + bool &is_valid_break, uint32_t timeout_ms); + + bool BarrierAlign(std::shared_ptr &message); + + BundleCheckStatus CheckBundle(const std::shared_ptr &message); + + static void SplitBundle(std::shared_ptr &message, uint64_t last_msg_id); +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/data_writer.cc b/streaming/src/data_writer.cc new file mode 100644 index 00000000..ff0496cc --- /dev/null +++ b/streaming/src/data_writer.cc @@ -0,0 +1,609 @@ +#include "data_writer.h" + +#include +#include +#include +#include +#include + +#include "util/streaming_util.h" + +namespace ray { +namespace streaming { + +StreamingStatus DataWriter::WriteChannelProcess(ProducerChannelInfo &channel_info, + bool *is_empty_message) { + // No message in buffer, empty message will be sent to downstream queue. + uint64_t buffer_remain = 0; + StreamingStatus write_queue_flag = WriteBufferToChannel(channel_info, buffer_remain); + int64_t current_ts = current_time_ms(); + if (write_queue_flag == StreamingStatus::EmptyRingBuffer && + current_ts - channel_info.message_pass_by_ts >= + runtime_context_->GetConfig().GetEmptyMessageTimeInterval()) { + write_queue_flag = WriteEmptyMessage(channel_info); + *is_empty_message = true; + STREAMING_LOG(DEBUG) << "send empty message bundle in q_id =>" + << channel_info.channel_id; + } + return write_queue_flag; +} + +StreamingStatus DataWriter::WriteBufferToChannel(ProducerChannelInfo &channel_info, + uint64_t &buffer_remain) { + StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer; + if (!IsMessageAvailableInBuffer(channel_info)) { + return StreamingStatus::EmptyRingBuffer; + } + + // Flush transient buffer to queue first. 
+ if (buffer_ptr->IsTransientAvaliable()) { + return WriteTransientBufferToChannel(channel_info); + } + + STREAMING_CHECK(CollectFromRingBuffer(channel_info, buffer_remain)) + << "empty data in ringbuffer, q id => " << channel_info.channel_id; + + return WriteTransientBufferToChannel(channel_info); +} + +void DataWriter::Run() { + STREAMING_LOG(INFO) << "Event server start"; + event_service_->Run(); + // Enable empty message timer after writer running. + empty_message_thread_ = + std::make_shared(&DataWriter::EmptyMessageTimerCallback, this); + flow_control_thread_ = + std::make_shared(&DataWriter::FlowControlTimer, this); +} + +/// Since every memory ring buffer's size is limited, when the writing buffer is +/// full, the user thread will be blocked, which will cause backpressure +/// naturally. +uint64_t DataWriter::WriteMessageToBufferRing(const ObjectID &q_id, uint8_t *data, + uint32_t data_size, + StreamingMessageType message_type) { + // TODO(lingxuan.zlx): currently, unsafe in multithreads + ProducerChannelInfo &channel_info = channel_info_map_[q_id]; + // Write message id stands for current lastest message id and differs from + // channel.current_message_id if it's barrier message. 
+ uint64_t &write_message_id = channel_info.current_message_id; + if (message_type == StreamingMessageType::Message) { + write_message_id++; + } + + STREAMING_LOG(DEBUG) << "WriteMessageToBufferRing q_id: " << q_id + << " data_size: " << data_size + << ", message_type=" << static_cast(message_type) + << ", data=" << Util::Byte2hex(data, data_size) + << ", current_message_id=" << write_message_id; + + auto &ring_buffer_ptr = channel_info.writer_ring_buffer; + while (ring_buffer_ptr->IsFull() && + runtime_context_->GetRuntimeStatus() == RuntimeStatus::Running) { + std::this_thread::sleep_for( + std::chrono::milliseconds(StreamingConfig::TIME_WAIT_UINT)); + } + if (runtime_context_->GetRuntimeStatus() != RuntimeStatus::Running) { + STREAMING_LOG(WARNING) << "stop in write message to ringbuffer"; + return 0; + } + ring_buffer_ptr->Push(std::make_shared( + data, data_size, write_message_id, message_type)); + + if (ring_buffer_ptr->Size() == 1) { + if (channel_info.in_event_queue) { + ++channel_info.in_event_queue_cnt; + STREAMING_LOG(DEBUG) << "user_event had been in event_queue"; + } else if (!channel_info.flow_control) { + channel_info.in_event_queue = true; + Event event(&channel_info, EventType::UserEvent, false); + event_service_->Push(event); + ++channel_info.user_event_cnt; + } + } + + return write_message_id; +} + +StreamingStatus DataWriter::InitChannel(const ObjectID &q_id, + const ChannelCreationParameter ¶m, + uint64_t channel_message_id, + uint64_t queue_size) { + ProducerChannelInfo &channel_info = channel_info_map_[q_id]; + channel_info.current_message_id = channel_message_id; + channel_info.channel_id = q_id; + channel_info.parameter = param; + channel_info.queue_size = queue_size; + STREAMING_LOG(WARNING) << " Init queue [" << q_id << "]"; + channel_info.writer_ring_buffer = std::make_shared( + runtime_context_->GetConfig().GetRingBufferCapacity(), + StreamingRingBufferType::SPSC); + channel_info.message_pass_by_ts = current_time_ms(); + std::shared_ptr 
channel; + + if (runtime_context_->IsMockTest()) { + channel = std::make_shared(transfer_config_, channel_info); + } else { + channel = std::make_shared(transfer_config_, channel_info); + } + + channel_map_.emplace(q_id, channel); + RETURN_IF_NOT_OK(channel->CreateTransferChannel()) + return StreamingStatus::OK; +} + +StreamingStatus DataWriter::Init(const std::vector &queue_id_vec, + const std::vector &init_params, + const std::vector &channel_message_id_vec, + const std::vector &queue_size_vec) { + STREAMING_CHECK(!queue_id_vec.empty() && !channel_message_id_vec.empty()); + STREAMING_LOG(INFO) << "Job name => " << runtime_context_->GetConfig().GetJobName(); + + output_queue_ids_ = queue_id_vec; + transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, queue_id_vec); + + for (size_t i = 0; i < queue_id_vec.size(); ++i) { + StreamingStatus status = InitChannel(queue_id_vec[i], init_params[i], + channel_message_id_vec[i], queue_size_vec[i]); + if (status != StreamingStatus::OK) { + return status; + } + } + + switch (runtime_context_->GetConfig().GetFlowControlType()) { + case proto::FlowControlType::UnconsumedSeqFlowControl: + flow_controller_ = std::make_shared( + channel_map_, runtime_context_->GetConfig().GetWriterConsumedStep()); + break; + default: + flow_controller_ = std::make_shared(); + break; + } + + reliability_helper_ = ReliabilityHelperFactory::CreateReliabilityHelper( + runtime_context_->GetConfig(), barrier_helper_, this, nullptr); + // Register empty event and user event to event server. 
+ event_service_ = std::make_shared(); + event_service_->Register( + EventType::EmptyEvent, + std::bind(&DataWriter::SendEmptyToChannel, this, std::placeholders::_1)); + event_service_->Register(EventType::UserEvent, std::bind(&DataWriter::WriteAllToChannel, + this, std::placeholders::_1)); + event_service_->Register(EventType::FlowEvent, std::bind(&DataWriter::WriteAllToChannel, + this, std::placeholders::_1)); + + runtime_context_->SetRuntimeStatus(RuntimeStatus::Running); + return StreamingStatus::OK; +} + +void DataWriter::BroadcastBarrier(uint64_t barrier_id, const uint8_t *data, + uint32_t data_size) { + STREAMING_LOG(INFO) << "broadcast checkpoint id : " << barrier_id; + barrier_helper_.MapBarrierToCheckpoint(barrier_id, barrier_id); + + if (barrier_helper_.Contains(barrier_id)) { + STREAMING_LOG(WARNING) << "replicated global barrier id => " << barrier_id; + return; + } + + std::vector barrier_id_vec; + barrier_helper_.GetAllBarrier(barrier_id_vec); + if (barrier_id_vec.size() > 0) { + // Show all stashed barrier ids that means these checkpoint are not finished + // yet. 
+ STREAMING_LOG(WARNING) << "[Writer] [Barrier] previous barrier(checkpoint) was fail " + "to do some opearting, ids => " + << Util::join(barrier_id_vec.begin(), barrier_id_vec.end(), + "|"); + } + StreamingBarrierHeader barrier_header(StreamingBarrierType::GlobalBarrier, barrier_id); + + auto barrier_payload = + StreamingMessage::MakeBarrierPayload(barrier_header, data, data_size); + auto payload_size = kBarrierHeaderSize + data_size; + for (auto &queue_id : output_queue_ids_) { + uint64_t barrier_message_id = WriteMessageToBufferRing( + queue_id, barrier_payload.get(), payload_size, StreamingMessageType::Barrier); + if (runtime_context_->GetRuntimeStatus() == RuntimeStatus::Interrupted) { + STREAMING_LOG(WARNING) << " stop right now"; + return; + } + + STREAMING_LOG(INFO) << "[Writer] [Barrier] write barrier to => " << queue_id + << ", barrier message id =>" << barrier_message_id + << ", barrier id => " << barrier_id; + } + + STREAMING_LOG(INFO) << "[Writer] [Barrier] global barrier id in runtime => " + << barrier_id; +} + +DataWriter::DataWriter(std::shared_ptr &runtime_context) + : transfer_config_(new Config()), runtime_context_(runtime_context) {} + +DataWriter::~DataWriter() { + // Return if fail to init streaming writer + if (runtime_context_->GetRuntimeStatus() == RuntimeStatus::Init) { + return; + } + runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted); + if (event_service_) { + event_service_->Stop(); + if (empty_message_thread_->joinable()) { + STREAMING_LOG(INFO) << "Empty message thread waiting for join"; + empty_message_thread_->join(); + } + if (flow_control_thread_->joinable()) { + STREAMING_LOG(INFO) << "FlowControl timer thread waiting for join"; + flow_control_thread_->join(); + } + int user_event_count = 0; + int empty_event_count = 0; + int flow_control_event_count = 0; + int in_event_queue_cnt = 0; + int queue_full_cnt = 0; + for (auto &output_queue : output_queue_ids_) { + ProducerChannelInfo &channel_info = 
channel_info_map_[output_queue]; + user_event_count += channel_info.user_event_cnt; + empty_event_count += channel_info.sent_empty_cnt; + flow_control_event_count += channel_info.flow_control_cnt; + in_event_queue_cnt += channel_info.in_event_queue_cnt; + queue_full_cnt += channel_info.queue_full_cnt; + } + STREAMING_LOG(WARNING) << "User event nums: " << user_event_count + << ", empty event nums: " << empty_event_count + << ", flow control event nums: " << flow_control_event_count + << ", queue full nums: " << queue_full_cnt + << ", in event queue: " << in_event_queue_cnt; + } + STREAMING_LOG(INFO) << "Writer client queue disconnect."; +} + +bool DataWriter::IsMessageAvailableInBuffer(ProducerChannelInfo &channel_info) { + return channel_info.writer_ring_buffer->IsTransientAvaliable() || + !channel_info.writer_ring_buffer->IsEmpty(); +} + +StreamingStatus DataWriter::WriteEmptyMessage(ProducerChannelInfo &channel_info) { + auto &q_id = channel_info.channel_id; + if (channel_info.message_last_commit_id < channel_info.current_message_id) { + // Abort to send empty message if ring buffer is not empty now. + STREAMING_LOG(DEBUG) << "q_id =>" << q_id << " abort to send empty, last commit id =>" + << channel_info.message_last_commit_id << ", channel max id => " + << channel_info.current_message_id; + return StreamingStatus::SkipSendEmptyMessage; + } + + // Make an empty bundle, use old ts from reloaded meta if it's not nullptr. 
+ StreamingMessageBundlePtr bundle_ptr = std::make_shared( + channel_info.current_message_id, current_time_ms()); + auto &q_ringbuffer = channel_info.writer_ring_buffer; + q_ringbuffer->ReallocTransientBuffer(bundle_ptr->ClassBytesSize()); + bundle_ptr->ToBytes(q_ringbuffer->GetTransientBufferMutable()); + + StreamingStatus status = channel_map_[q_id]->ProduceItemToChannel( + const_cast(q_ringbuffer->GetTransientBuffer()), + q_ringbuffer->GetTransientBufferSize()); + STREAMING_LOG(DEBUG) << "q_id =>" << q_id << " send empty message, meta info =>" + << bundle_ptr->ToString(); + + q_ringbuffer->FreeTransientBuffer(); + RETURN_IF_NOT_OK(status) + channel_info.message_pass_by_ts = current_time_ms(); + return StreamingStatus::OK; +} + +StreamingStatus DataWriter::WriteTransientBufferToChannel( + ProducerChannelInfo &channel_info) { + StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer; + StreamingStatus status = channel_map_[channel_info.channel_id]->ProduceItemToChannel( + buffer_ptr->GetTransientBufferMutable(), buffer_ptr->GetTransientBufferSize()); + RETURN_IF_NOT_OK(status) + auto transient_bundle_meta = + StreamingMessageBundleMeta::FromBytes(buffer_ptr->GetTransientBuffer()); + bool is_barrier_bundle = transient_bundle_meta->IsBarrier(); + // Force delete to avoid super block memory isn't released so long + // if it's barrier bundle. + buffer_ptr->FreeTransientBuffer(is_barrier_bundle); + channel_info.message_last_commit_id = transient_bundle_meta->GetLastMessageId(); + return StreamingStatus::OK; +} + +bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo &channel_info, + uint64_t &buffer_remain) { + StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer; + auto &q_id = channel_info.channel_id; + + std::list message_list; + uint32_t bundle_buffer_size = 0; + const uint32_t max_queue_item_size = channel_info.queue_size; + + bool is_barrier = false; + + // Pop until one of the following condition meets: + // 1. 
ring buffer is empty + // 2. message count in bundle is larger than ring buffer size + // 3. sum of data size of messages in bundle is larger than streaming queue size + // 4. message type changed + while (message_list.size() < runtime_context_->GetConfig().GetRingBufferCapacity() && + !buffer_ptr->IsEmpty()) { + StreamingMessagePtr &message_ptr = buffer_ptr->Front(); + STREAMING_LOG(DEBUG) << "Collecting message " << *message_ptr + << ", message_list_size=" << message_list.size() + << ", buffer capacity=" + << runtime_context_->GetConfig().GetRingBufferCapacity() + << ", buffer size=" << buffer_ptr->Size(); + + uint32_t message_total_size = message_ptr->ClassBytesSize(); + if (!message_list.empty() && + bundle_buffer_size + message_total_size >= max_queue_item_size) { + STREAMING_LOG(DEBUG) << "message total size " << message_total_size + << " max queue item size => " << max_queue_item_size; + break; + } + if (!message_list.empty() && + message_list.back()->GetMessageType() != message_ptr->GetMessageType()) { + STREAMING_LOG(DEBUG) << "Different message type detected, break collecting, last " + "message type in list=" + << static_cast(message_list.back()->GetMessageType()) + << ", current collecing message type=" + << static_cast(message_ptr->GetMessageType()); + break; + } + bundle_buffer_size += message_total_size; + message_list.push_back(message_ptr); + buffer_ptr->Pop(); + buffer_remain = buffer_ptr->Size(); + is_barrier = message_ptr->IsBarrier(); + STREAMING_LOG(DEBUG) << "Message " << *message_ptr + << " collected, message_list_size=" << message_list.size() + << ", buffer capacity=" + << runtime_context_->GetConfig().GetRingBufferCapacity() + << ", buffer size=" << buffer_ptr->Size(); + } + + if (bundle_buffer_size >= channel_info.queue_size) { + STREAMING_LOG(ERROR) << "bundle buffer is too large to store q id => " << q_id + << ", bundle size => " << bundle_buffer_size + << ", queue size => " << channel_info.queue_size; + } + + StreamingMessageBundlePtr 
bundle_ptr; + StreamingMessageBundleType bundleType = StreamingMessageBundleType::Bundle; + if (is_barrier) { + bundleType = StreamingMessageBundleType::Barrier; + } + bundle_ptr = std::make_shared( + std::move(message_list), current_time_ms(), message_list.back()->GetMessageId(), + bundleType, bundle_buffer_size); + + STREAMING_LOG(DEBUG) << "CollectFromRingBuffer done, bundle=" << *bundle_ptr; + + buffer_ptr->ReallocTransientBuffer(bundle_ptr->ClassBytesSize()); + bundle_ptr->ToBytes(buffer_ptr->GetTransientBufferMutable()); + + STREAMING_CHECK(bundle_ptr->ClassBytesSize() == buffer_ptr->GetTransientBufferSize()); + return true; +} + +void DataWriter::Stop() { + for (auto &output_queue : output_queue_ids_) { + ProducerChannelInfo &channel_info = channel_info_map_[output_queue]; + while (!channel_info.writer_ring_buffer->IsEmpty()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted); +} + +bool DataWriter::WriteAllToChannel(ProducerChannelInfo *info) { + ProducerChannelInfo &channel_info = *info; + channel_info.in_event_queue = false; + while (true) { + if (RuntimeStatus::Running != runtime_context_->GetRuntimeStatus()) { + return false; + } + // Stop to write remained messages to channel if channel has been blocked by + // flow control. + if (channel_info.flow_control) { + break; + } + // Check this channel is blocked by flow control or not. 
+ if (flow_controller_->ShouldFlowControl(channel_info)) { + channel_info.flow_control = true; + break; + } + uint64_t ring_buffer_remain = channel_info.writer_ring_buffer->Size(); + StreamingStatus write_status = WriteBufferToChannel(channel_info, ring_buffer_remain); + int64_t current_ts = current_time_ms(); + if (StreamingStatus::OK == write_status) { + channel_info.message_pass_by_ts = current_ts; + } else if (StreamingStatus::FullChannel == write_status || + StreamingStatus::OutOfMemory == write_status) { + channel_info.flow_control = true; + ++channel_info.queue_full_cnt; + STREAMING_LOG(DEBUG) << "FullChannel after writing to channel, queue_full_cnt:" + << channel_info.queue_full_cnt; + RefreshChannelAndNotifyConsumed(channel_info); + } else if (StreamingStatus::EmptyRingBuffer != write_status) { + STREAMING_LOG(INFO) << channel_info.channel_id + << ":something wrong when WriteToQueue " + << "write buffer status => " + << static_cast(write_status); + break; + } + if (ring_buffer_remain == 0 && + !channel_info.writer_ring_buffer->IsTransientAvaliable()) { + break; + } + } + return true; +} + +bool DataWriter::SendEmptyToChannel(ProducerChannelInfo *channel_info) { + WriteEmptyMessage(*channel_info); + return true; +} + +void DataWriter::EmptyMessageTimerCallback() { + while (true) { + if (RuntimeStatus::Running != runtime_context_->GetRuntimeStatus()) { + return; + } + + int64_t current_ts = current_time_ms(); + int64_t min_passby_message_ts = current_ts; + int count = 0; + for (auto output_queue : output_queue_ids_) { + if (RuntimeStatus::Running != runtime_context_->GetRuntimeStatus()) { + return; + } + ProducerChannelInfo &channel_info = channel_info_map_[output_queue]; + if (channel_info.flow_control || channel_info.writer_ring_buffer->Size() || + current_ts < channel_info.message_pass_by_ts) { + continue; + } + if (current_ts - channel_info.message_pass_by_ts >= + runtime_context_->GetConfig().GetEmptyMessageTimeInterval()) { + Event event(&channel_info, 
EventType::EmptyEvent, true); + event_service_->Push(event); + ++channel_info.sent_empty_cnt; + ++count; + continue; + } + if (min_passby_message_ts > channel_info.message_pass_by_ts) { + min_passby_message_ts = channel_info.message_pass_by_ts; + } + } + STREAMING_LOG(DEBUG) << "EmptyThd:produce empty_events:" << count + << " eventqueue size:" << event_service_->EventNums() + << " next_sleep_time:" + << runtime_context_->GetConfig().GetEmptyMessageTimeInterval() - + current_ts + min_passby_message_ts; + + for (const auto &output_queue : output_queue_ids_) { + ProducerChannelInfo &channel_info = channel_info_map_[output_queue]; + STREAMING_LOG(DEBUG) << output_queue << "==ring_buffer size:" + << channel_info.writer_ring_buffer->Size() + << " transient_buffer size:" + << channel_info.writer_ring_buffer->GetTransientBufferSize() + << " in_event_queue:" << channel_info.in_event_queue + << " flow_control:" << channel_info.flow_control + << " user_event_cnt:" << channel_info.user_event_cnt + << " flow_control_event:" << channel_info.flow_control_cnt + << " empty_event_cnt:" << channel_info.sent_empty_cnt + << " rb_full_cnt:" << channel_info.rb_full_cnt + << " queue_full_cnt:" << channel_info.queue_full_cnt; + } + + std::this_thread::sleep_for(std::chrono::milliseconds( + runtime_context_->GetConfig().GetEmptyMessageTimeInterval() - current_ts + + min_passby_message_ts)); + } +} + +void DataWriter::RefreshChannelAndNotifyConsumed(ProducerChannelInfo &channel_info) { + // Refresh current downstream consumed seq id. + channel_map_[channel_info.channel_id]->RefreshChannelInfo(); + // Notify the consumed information to local channel. 
+ NotifyConsumedItem(channel_info, channel_info.queue_info.consumed_message_id); +} + +void DataWriter::NotifyConsumedItem(ProducerChannelInfo &channel_info, uint32_t offset) { + if (offset > channel_info.current_message_id) { + STREAMING_LOG(WARNING) << "Can not notify consumed this offset " << offset + << " that's out of range, max seq id " + << channel_info.current_message_id; + } else { + channel_map_[channel_info.channel_id]->NotifyChannelConsumed(offset); + } +} + +void DataWriter::FlowControlTimer() { + std::chrono::milliseconds MockTimer( + runtime_context_->GetConfig().GetEventDrivenFlowControlInterval()); + while (true) { + if (runtime_context_->GetRuntimeStatus() != RuntimeStatus::Running) { + return; + } + for (const auto &output_queue : output_queue_ids_) { + if (runtime_context_->GetRuntimeStatus() != RuntimeStatus::Running) { + return; + } + ProducerChannelInfo &channel_info = channel_info_map_[output_queue]; + if (!channel_info.flow_control) { + continue; + } + if (!flow_controller_->ShouldFlowControl(channel_info)) { + channel_info.flow_control = false; + Event event{&channel_info, EventType::FlowEvent, + channel_info.writer_ring_buffer->IsFull()}; + event_service_->Push(event); + ++channel_info.flow_control_cnt; + } + } + std::this_thread::sleep_for(MockTimer); + } +} + +void DataWriter::GetOffsetInfo( + std::unordered_map *&offset_map) { + offset_map = &channel_info_map_; +} + +void DataWriter::ClearCheckpoint(uint64_t barrier_id) { + if (!barrier_helper_.Contains(barrier_id)) { + STREAMING_LOG(WARNING) << "no such barrier id => " << barrier_id; + return; + } + + std::string global_barrier_id_list_str = "|"; + + for (auto &queue_id : output_queue_ids_) { + uint64_t q_global_barrier_msg_id = 0; + StreamingStatus status = barrier_helper_.GetMsgIdByBarrierId(queue_id, barrier_id, + q_global_barrier_msg_id); + ProducerChannelInfo &channel_info = channel_info_map_[queue_id]; + if (status == StreamingStatus::OK) { + ClearCheckpointId(channel_info, 
q_global_barrier_msg_id); + } else { + STREAMING_LOG(WARNING) << "no seq record in q => " << queue_id << ", barrier id => " + << barrier_id; + } + global_barrier_id_list_str += + queue_id.Hex() + " : " + std::to_string(q_global_barrier_msg_id) + "| "; + reliability_helper_->CleanupCheckpoint(channel_info, barrier_id); + } + + STREAMING_LOG(INFO) + << "[Writer] [Barrier] [clear] global barrier flag, global barrier id => " + << barrier_id << ", seq id map => " << global_barrier_id_list_str; + + barrier_helper_.ReleaseBarrierMapById(barrier_id); + barrier_helper_.ReleaseBarrierMapCheckpointByBarrierId(barrier_id); +} + +void DataWriter::ClearCheckpointId(ProducerChannelInfo &channel_info, uint64_t msg_id) { + AutoSpinLock lock(notify_flag_); + + uint64_t current_msg_id = channel_info.current_message_id; + if (msg_id > current_msg_id) { + STREAMING_LOG(WARNING) << "current_msg_id=" << current_msg_id + << ", msg_id to be cleared=" << msg_id + << ", channel id = " << channel_info.channel_id; + } + channel_map_[channel_info.channel_id]->NotifyChannelConsumed(msg_id); + + STREAMING_LOG(DEBUG) << "clearing data from msg_id=" << msg_id + << ", qid= " << channel_info.channel_id; +} + +void DataWriter::GetChannelOffset(std::vector &result) { + for (auto &q_id : output_queue_ids_) { + result.push_back(channel_info_map_[q_id].current_message_id); + } +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/data_writer.h b/streaming/src/data_writer.h new file mode 100644 index 00000000..2b1582d9 --- /dev/null +++ b/streaming/src/data_writer.h @@ -0,0 +1,188 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "channel/channel.h" +#include "config/streaming_config.h" +#include "event_service.h" +#include "flow_control.h" +#include "message/message_bundle.h" +#include "reliability/barrier_helper.h" +#include "reliability_helper.h" +#include "runtime_context.h" + +namespace ray { +namespace streaming { +class ReliabilityHelper; + +/// DataWriter is designed for data transporting between upstream and downstream. +/// After the user sends the data, it does not immediately send the data to +/// downstream, but caches it in the corresponding memory ring buffer. There is +/// a spearate transfer thread (setup in WriterLoopForward function) to collect +/// the messages from all the ringbuffers, and write them to the corresponding +/// transmission channels, which is backed by StreamingQueue. Actually, the +/// advantage is that the user thread will not be affected by the transmission +/// speed during the data transfer. And also the transfer thread can automatically +/// batch the catched data from memory buffer into a data bundle to reduce +/// transmission overhead. In addtion, when there is no data in the ringbuffer, +/// it will also send an empty bundle, so downstream can know that and process +/// accordingly. It will sleep for a short interval to save cpu if all ring +/// buffers have no data in that moment. +class DataWriter { + public: + // For mock writer accessing inner fields. 
+ friend class MockWriter; + + explicit DataWriter(std::shared_ptr &runtime_context); + virtual ~DataWriter(); + + /// Streaming writer client initialization. + /// \param queue_id_vec queue id vector + /// \param init_params some parameters for initializing channels + /// \param channel_message_id_vec channel seq id is related with message checkpoint + /// \param queue_size queue size (memory size not length) + StreamingStatus Init(const std::vector &channel_ids, + const std::vector &init_params, + const std::vector &channel_message_id_vec, + const std::vector &queue_size_vec); + + /// To increase throughout, we employed an output buffer for message transformation, + /// which means we merge a lot of message to a message bundle and no message will be + /// pushed into queue directly util daemon thread does this action. + /// Additionally, writing will block when buffer ring is full intentionly. + /// \param q_id, destination channel id + /// \param data, pointer of raw data + /// \param data_size, raw data size + /// \param message_type + /// \return message seq iq + uint64_t WriteMessageToBufferRing( + const ObjectID &q_id, uint8_t *data, uint32_t data_size, + StreamingMessageType message_type = StreamingMessageType::Message); + + /// Send barrier to all channel. note there are user defined data in barrier bundle + /// \param barrier_id + /// \param data + /// \param data_size + /// + void BroadcastBarrier(uint64_t barrier_id, const uint8_t *data, uint32_t data_size); + + /// To relieve stress from large source/input data, we define a new function + /// clear_check_point + /// in producer/writer class. Worker can invoke this function if and only if + /// notify_consumed each item + /// flag is passed in reader/consumer, which means writer's producing became more + /// rhythmical and reader + /// can't walk on old way anymore. 
+ /// \param barrier_id: user-defined numerical checkpoint id + void ClearCheckpoint(uint64_t barrier_id); + + /// Replay all queue from checkpoint, it's useful under FO + /// \param result offset vector + void GetChannelOffset(std::vector &result); + + void Run(); + + void Stop(); + + /// Get offset information about channels for checkpoint. + /// \param offset_map (return value) + void GetOffsetInfo(std::unordered_map *&offset_map); + + private: + bool IsMessageAvailableInBuffer(ProducerChannelInfo &channel_info); + + /// This function handles two scenarios. When there is data in the transient + /// buffer, the existing data is written into the channel first, otherwise a + /// certain amount of message is first collected from the buffer and serialized + /// into the transient buffer, and finally written to the channel. + /// \\param channel_info + /// \\param buffer_remain + StreamingStatus WriteBufferToChannel(ProducerChannelInfo &channel_info, + uint64_t &buffer_remain); + + /// Push empty message when no valid message or bundle was produced each time + /// interval. + /// \param channel_info + StreamingStatus WriteEmptyMessage(ProducerChannelInfo &channel_info); + + /// Flush all data from transient buffer to channel for transporting. + /// \param channel_info + StreamingStatus WriteTransientBufferToChannel(ProducerChannelInfo &channel_info); + + bool CollectFromRingBuffer(ProducerChannelInfo &channel_info, uint64_t &buffer_remain); + + StreamingStatus WriteChannelProcess(ProducerChannelInfo &channel_info, + bool *is_empty_message); + + StreamingStatus InitChannel(const ObjectID &q_id, const ChannelCreationParameter ¶m, + uint64_t channel_message_id, uint64_t queue_size); + + /// Write all messages to channel util ringbuffer is empty. + /// \param channel_info + bool WriteAllToChannel(ProducerChannelInfo *channel_info); + + /// Trigger an empty message for channel with no valid data. 
+ /// \param channel_info + bool SendEmptyToChannel(ProducerChannelInfo *channel_info); + + void EmptyMessageTimerCallback(); + + /// Notify channel consumed refreshing downstream queue stats. + void RefreshChannelAndNotifyConsumed(ProducerChannelInfo &channel_info); + + /// Notify channel consumed by given offset. + void NotifyConsumedItem(ProducerChannelInfo &channel_info, uint32_t offset); + + void FlowControlTimer(); + + void ClearCheckpointId(ProducerChannelInfo &channel_info, uint64_t seq_id); + + private: + std::shared_ptr event_service_; + + std::shared_ptr empty_message_thread_; + + std::shared_ptr flow_control_thread_; + // One channel have unique identity. + std::vector output_queue_ids_; + // Flow controller makes a decision when it's should be blocked and avoid + // unnecessary overflow. + std::shared_ptr flow_controller_; + + StreamingBarrierHelper barrier_helper_; + std::shared_ptr reliability_helper_; + + // Make thread-safe between loop thread and user thread. + // High-level runtime send notification about clear checkpoint if global + // checkpoint is finished and low-level will auto flush & evict item memory + // when no more space is available. + std::atomic_flag notify_flag_ = ATOMIC_FLAG_INIT; + + protected: + std::unordered_map channel_info_map_; + /// ProducerChannel is middle broker for data transporting and all downstream + /// producer channels will be channel_map_. 
+ std::unordered_map> channel_map_; + std::shared_ptr transfer_config_; + std::shared_ptr runtime_context_; +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/event_service.cc b/streaming/src/event_service.cc new file mode 100644 index 00000000..dc1003f9 --- /dev/null +++ b/streaming/src/event_service.cc @@ -0,0 +1,216 @@ +#include "event_service.h" + +#include +#include + +namespace ray { +namespace streaming { + +EventQueue::~EventQueue() { + is_active_ = false; + no_full_cv_.notify_all(); + no_empty_cv_.notify_all(); +}; + +void EventQueue::Unfreeze() { is_active_ = true; } + +void EventQueue::Freeze() { + is_active_ = false; + no_empty_cv_.notify_all(); + no_full_cv_.notify_all(); +} + +void EventQueue::Push(const Event &t) { + std::unique_lock lock(ring_buffer_mutex_); + while (Size() >= capacity_ && is_active_) { + STREAMING_LOG(WARNING) << " EventQueue is full, its size:" << Size() + << " capacity:" << capacity_ + << " buffer size:" << buffer_.size() + << " urgent_buffer size:" << urgent_buffer_.size(); + no_full_cv_.wait(lock); + STREAMING_LOG(WARNING) << "Event server is full_sleep be notified"; + } + if (!is_active_) { + return; + } + if (t.urgent) { + buffer_.push(t); + } else { + urgent_buffer_.push(t); + } + if (1 == Size()) { + no_empty_cv_.notify_one(); + } +} + +void EventQueue::Pop() { + std::unique_lock lock(ring_buffer_mutex_); + if (Size() >= capacity_) { + STREAMING_LOG(WARNING) << "Pop should notify" + << " size : " << Size(); + } + if (urgent_) { + urgent_buffer_.pop(); + } else { + buffer_.pop(); + } + no_full_cv_.notify_all(); +} + +constexpr int EventQueue::kConditionTimeoutMs; +void EventQueue::WaitFor(std::unique_lock &lock) { + // To avoid deadlock when EventQueue is empty but is_active is changed in other + // thread, Event queue should awaken this condtion variable and check it again. 
+ while (is_active_ && Empty()) { + int timeout = kConditionTimeoutMs; // This avoids const & to static (linking error) + if (!no_empty_cv_.wait_for(lock, std::chrono::milliseconds(timeout), + [this]() { return !is_active_ || !Empty(); })) { + STREAMING_LOG(DEBUG) << "No empty condition variable wait timeout." + << " Empty => " << Empty() << ", is active " << is_active_; + } + } +} + +bool EventQueue::Get(Event &evt) { + std::unique_lock lock(ring_buffer_mutex_); + WaitFor(lock); + if (!is_active_) { + return false; + } + if (!urgent_buffer_.empty()) { + urgent_ = true; + evt = urgent_buffer_.front(); + } else { + urgent_ = false; + evt = buffer_.front(); + } + return true; +} + +Event EventQueue::PopAndGet() { + std::unique_lock lock(ring_buffer_mutex_); + WaitFor(lock); + if (!is_active_) { + // Return error event if queue is active. + return Event({nullptr, EventType::ErrorEvent, false}); + } + if (!urgent_buffer_.empty()) { + Event res = urgent_buffer_.front(); + urgent_buffer_.pop(); + if (Full()) { + no_full_cv_.notify_one(); + } + return res; + } + Event res = buffer_.front(); + buffer_.pop(); + if (Size() + 1 == capacity_) no_full_cv_.notify_one(); + return res; +} + +Event &EventQueue::Front() { + std::unique_lock lock(ring_buffer_mutex_); + if (urgent_buffer_.size()) { + return urgent_buffer_.front(); + } + return buffer_.front(); +} + +EventService::EventService(uint32_t event_size) + : worker_id_(CoreWorkerProcess::IsInitialized() + ? CoreWorkerProcess::GetCoreWorker().GetWorkerID() + : WorkerID::Nil()), + event_queue_(std::make_shared(event_size)), + stop_flag_(false) {} +EventService::~EventService() { + stop_flag_ = true; + // No need to join if loop thread has never been created. 
+ if (loop_thread_ && loop_thread_->joinable()) { + STREAMING_LOG(WARNING) << "Loop Thread Stopped"; + loop_thread_->join(); + } +} + +void EventService::Run() { + stop_flag_ = false; + event_queue_->Unfreeze(); + loop_thread_ = std::make_shared(&EventService::LoopThreadHandler, this); + STREAMING_LOG(WARNING) << "event_server run"; +} + +void EventService::Stop() { + stop_flag_ = true; + event_queue_->Freeze(); + if (loop_thread_->joinable()) { + loop_thread_->join(); + } + STREAMING_LOG(WARNING) << "event_server stop"; +} + +bool EventService::Register(const EventType &type, const Handle &handle) { + if (event_handle_map_.find(type) != event_handle_map_.end()) { + STREAMING_LOG(WARNING) << "EventType had been registered!"; + } + event_handle_map_[type] = handle; + return true; +} + +void EventService::Push(const Event &event) { event_queue_->Push(event); } + +void EventService::Execute(Event &event) { + if (event_handle_map_.find(event.type) == event_handle_map_.end()) { + STREAMING_LOG(WARNING) << "Handle has never been registered yet, type => " + << static_cast(event.type); + return; + } + Handle &handle = event_handle_map_[event.type]; + if (handle(event.channel_info)) { + event_queue_->Pop(); + } +} + +void EventService::LoopThreadHandler() { + if (CoreWorkerProcess::IsInitialized()) { + CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id_); + } + while (true) { + if (stop_flag_) { + break; + } + Event event; + if (event_queue_->Get(event)) { + Execute(event); + } + } +} + +void EventService::RemoveDestroyedChannelEvent(const std::vector &removed_ids) { + // NOTE(lingxuan.zlx): To prevent producing invalid event for removed + // channels, we pop out all invalid channel related events(push it to + // original queue if it has no connection with removed channels). 
+ std::unordered_set removed_set(removed_ids.begin(), removed_ids.end()); + size_t total_event_nums = EventNums(); + STREAMING_LOG(INFO) << "Remove Destroyed channel event, removed_ids size " + << removed_ids.size() << ", total event size " << total_event_nums; + size_t removed_related_num = 0; + event_queue_->Unfreeze(); + for (size_t i = 0; i < total_event_nums; ++i) { + Event event; + if (!event_queue_->Get(event) || !event.channel_info) { + STREAMING_LOG(WARNING) << "Fail to get event or channel_info is null, i = " << i; + continue; + } + if (removed_set.find(event.channel_info->channel_id) != removed_set.end()) { + removed_related_num++; + } else { + event_queue_->Push(event); + } + event_queue_->Pop(); + } + event_queue_->Freeze(); + STREAMING_LOG(INFO) << "Total event num => " << total_event_nums + << ", removed related num => " << removed_related_num; +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/event_service.h b/streaming/src/event_service.h new file mode 100644 index 00000000..874fff02 --- /dev/null +++ b/streaming/src/event_service.h @@ -0,0 +1,152 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "channel/channel.h" +#include "ring_buffer/ring_buffer.h" +#include "util/streaming_util.h" + +namespace ray { +namespace streaming { + +enum class EventType : uint8_t { + // A message created by user writing. + UserEvent = 0, + // Unblock upstream writing when it's under flowcontrol. + FlowEvent = 1, + // Trigger an empty message by timer. + EmptyEvent = 2, + FullChannel = 3, + // Recovery at the beginning. + Reload = 4, + // Error event if event queue is not active. 
+ ErrorEvent = 5 +}; + +struct EnumTypeHash { + template + std::size_t operator()(const T &t) const { + return static_cast(t); + } +}; + +struct Event { + ProducerChannelInfo *channel_info; + EventType type; + bool urgent; + Event() = default; + Event(ProducerChannelInfo *channel_info, EventType type, bool urgent) { + this->channel_info = channel_info; + this->type = type; + this->urgent = urgent; + } +}; + +/// Data writer utilizes what's called an event-driven programming model +/// that includes two important components: event service and event +/// queue. In the process of data transmission, the writer will first define +/// the processing method of corresponding events. However, by triggering +/// different events in actual operation, these events will be put into the event +/// queue, and finally the event server will schedule the previously registered +/// processing functions ordered by its priority. +class EventQueue { + public: + EventQueue(size_t size) : urgent_(false), capacity_(size), is_active_(true) {} + + virtual ~EventQueue(); + + /// Resume event queue to normal model. + void Unfreeze(); + + /// Push is prohibited when event queue is not active. + void Freeze(); + + void Push(const Event &t); + + void Pop(); + + bool Get(Event &evt); + + Event PopAndGet(); + + Event &Front(); + + inline size_t Capacity() const { return capacity_; } + + /// It mainly divides event into two different levels: normal event and urgent + /// event, and the total size of the queue is the sum of them. + inline size_t Size() const { return buffer_.size() + urgent_buffer_.size(); } + + private: + /// (NOTE:lingxuan.zlx) There is no strict thread-safe when query empty or full, + /// but it can reduce lock contention. In fact, these functions are thread-safe + /// when invoked via Push/Pop where buffer size will only be changed in whole process. 
+ inline bool Empty() const { return buffer_.empty() && urgent_buffer_.empty(); } + + inline bool Full() const { return buffer_.size() + urgent_buffer_.size() == capacity_; } + + /// Wait for queue util it's timeout or any stuff in. + void WaitFor(std::unique_lock &lock); + + private: + std::mutex ring_buffer_mutex_; + std::condition_variable no_empty_cv_; + std::condition_variable no_full_cv_; + // Normal events wil be pushed into buffer_. + std::queue buffer_; + // This field urgent_buffer_ is used for serving urgent event. + std::queue urgent_buffer_; + // Urgent event will be poped out first if urgent_ flag is true. + bool urgent_; + size_t capacity_; + // Event service active flag. + bool is_active_; + // Pop/Get timeout ms for condition variables wait. + static constexpr int kConditionTimeoutMs = 200; +}; + +class EventService { + public: + /// User-define event handle for different types. + typedef std::function Handle; + + EventService(uint32_t event_size = kEventQueueCapacity); + + ~EventService(); + + void Run(); + + void Stop(); + + bool Register(const EventType &type, const Handle &handle); + + void Push(const Event &event); + + inline size_t EventNums() const { return event_queue_->Size(); } + + void RemoveDestroyedChannelEvent(const std::vector &removed_ids); + + private: + void Execute(Event &event); + + /// A single thread should be invoked to run this loop function, so that + /// event server can poll and execute registered callback function event + /// one by one. 
+ void LoopThreadHandler(); + + private: + WorkerID worker_id_; + std::unordered_map event_handle_map_; + std::shared_ptr event_queue_; + std::shared_ptr loop_thread_; + + static constexpr int kEventQueueCapacity = 1000; + + bool stop_flag_; +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/flow_control.cc b/streaming/src/flow_control.cc new file mode 100644 index 00000000..b49d10c8 --- /dev/null +++ b/streaming/src/flow_control.cc @@ -0,0 +1,35 @@ +#include "flow_control.h" + +namespace ray { +namespace streaming { + +UnconsumedSeqFlowControl::UnconsumedSeqFlowControl( + std::unordered_map> &channel_map, + uint32_t step) + : channel_map_(channel_map), consumed_step_(step){}; + +bool UnconsumedSeqFlowControl::ShouldFlowControl(ProducerChannelInfo &channel_info) { + auto &queue_info = channel_info.queue_info; + if (queue_info.target_message_id <= channel_info.current_message_id) { + channel_map_[channel_info.channel_id]->RefreshChannelInfo(); + // Target seq id is maximum upper limit in current condition. + channel_info.queue_info.target_message_id = + channel_info.queue_info.consumed_message_id + consumed_step_; + STREAMING_LOG(DEBUG) + << "Flow control stop writing to downstream, current message id => " + << channel_info.current_message_id << ", target message id => " + << queue_info.target_message_id << ", consumed_id => " + << queue_info.consumed_message_id << ", q id => " << channel_info.channel_id + << ". if this log keeps printing, it means something wrong " + "with queue's info API, or downstream node is not " + "consuming data."; + // Double check after refreshing if target seq id is changed. 
+ if (queue_info.target_message_id <= channel_info.current_message_id) { + return true; + } + } + return false; +} +} // namespace streaming + +} // namespace ray diff --git a/streaming/src/flow_control.h b/streaming/src/flow_control.h new file mode 100644 index 00000000..005e75b8 --- /dev/null +++ b/streaming/src/flow_control.h @@ -0,0 +1,44 @@ +#pragma once + +#include "channel/channel.h" + +namespace ray { +namespace streaming { +class ProducerTransfer; +/// We devise a flow control system in queue channel, and that's called flow +/// control by unconsumed seq. Upstream worker will detect consumer statistics via +/// api so it can keep fixed length messages in this process, which makes a +/// continuous datastream in channel or on the transporting way, then downstream +/// can read them from channel immediately. +/// To debug or compare with theses flow control methods, we also support +/// no-flow-control that will do nothing in transporting. +class FlowControl { + public: + virtual ~FlowControl() = default; + virtual bool ShouldFlowControl(ProducerChannelInfo &channel_info) = 0; +}; + +class NoFlowControl : public FlowControl { + public: + bool ShouldFlowControl(ProducerChannelInfo &channel_info) { return false; } + ~NoFlowControl() = default; +}; + +class UnconsumedSeqFlowControl : public FlowControl { + public: + UnconsumedSeqFlowControl( + std::unordered_map> &channel_map, + uint32_t step); + ~UnconsumedSeqFlowControl() = default; + bool ShouldFlowControl(ProducerChannelInfo &channel_info); + + private: + /// NOTE(wanxing.wwx) Reference to channel_map_ variable in DataWriter. + /// Flow-control is checked in FlowControlThread, so channel_map_ is accessed + /// in multithread situation. Especially, while rescaling, channel_map_ maybe + /// changed. But for now, FlowControlThread is stopped before rescaling. 
+ std::unordered_map> &channel_map_; + uint32_t consumed_step_; +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc new file mode 100644 index 00000000..ba4a4d40 --- /dev/null +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc @@ -0,0 +1,133 @@ +#include "io_ray_streaming_runtime_transfer_DataReader.h" + +#include + +#include "data_reader.h" +#include "runtime_context.h" +#include "streaming_jni_common.h" + +using namespace ray; +using namespace ray::streaming; + +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( + JNIEnv *env, jclass, jobject streaming_queue_initial_parameters, + jobjectArray input_channels, jlongArray msg_id_array, jlong timer_interval, + jobject creation_status, jbyteArray config_bytes, jboolean is_mock) { + STREAMING_LOG(INFO) << "[JNI]: create DataReader."; + std::vector parameter_vec; + ParseChannelInitParameters(env, streaming_queue_initial_parameters, parameter_vec); + std::vector input_channels_ids = + jarray_to_object_id_vec(env, input_channels); + std::vector msg_ids = LongVectorFromJLongArray(env, msg_id_array).data; + + auto ctx = std::make_shared(); + RawDataFromJByteArray conf(env, config_bytes); + if (conf.data_size > 0) { + STREAMING_LOG(INFO) << "load config, config bytes size: " << conf.data_size; + ctx->SetConfig(conf.data, conf.data_size); + } + if (is_mock) { + ctx->MarkMockTest(); + } + + // init reader + auto reader = new DataReader(ctx); + std::vector creation_status_vec; + reader->Init(input_channels_ids, parameter_vec, msg_ids, creation_status_vec, + timer_interval); + + // add creation status to Java's List + jclass array_list_cls = env->GetObjectClass(creation_status); + jclass integer_cls = env->FindClass("java/lang/Integer"); + jmethodID array_list_add = + env->GetMethodID(array_list_cls, 
"add", "(Ljava/lang/Object;)Z"); + for (auto &status : creation_status_vec) { + jmethodID integer_init = env->GetMethodID(integer_cls, "", "(I)V"); + jobject integer_obj = + env->NewObject(integer_cls, integer_init, static_cast(status)); + env->CallBooleanMethod(creation_status, array_list_add, integer_obj); + } + STREAMING_LOG(INFO) << "create native DataReader succeed"; + return reinterpret_cast(reader); +} + +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBundleNative( + JNIEnv *env, jobject, jlong reader_ptr, jlong timeout_millis, jlong out, + jlong meta_addr) { + std::shared_ptr bundle; + auto reader = reinterpret_cast(reader_ptr); + auto status = reader->GetBundle((uint32_t)timeout_millis, bundle); + + // over timeout, return empty array. + if (StreamingStatus::Interrupted == status) { + throwChannelInterruptException(env, "reader interrupted."); + } else if (StreamingStatus::GetBundleTimeOut == status) { + } else if (StreamingStatus::InitQueueFailed == status) { + throwRuntimeException(env, "init channel failed"); + } + + if (StreamingStatus::OK != status) { + *reinterpret_cast(out) = 0; + *reinterpret_cast(out + 8) = 0; + return; + } + + // bundle data + // In streaming queue, bundle data and metadata will be different args of direct call, + // so we separate it here for future extensibility. 
+ *reinterpret_cast(out) = + reinterpret_cast(bundle->data + kMessageBundleHeaderSize); + *reinterpret_cast(out + 8) = bundle->data_size - kMessageBundleHeaderSize; + + // bundle metadata + auto meta = reinterpret_cast(meta_addr); + // bundle header written by writer + std::memcpy(meta, bundle->data, kMessageBundleHeaderSize); + // append qid + std::memcpy(meta + kMessageBundleHeaderSize, bundle->from.Data(), kUniqueIDSize); +} + +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative( + JNIEnv *env, jobject thisObj, jlong ptr) { + auto reader = reinterpret_cast(ptr); + reader->Stop(); +} + +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *env, + jobject thisObj, + jlong ptr) { + delete reinterpret_cast(ptr); +} + +JNIEXPORT jbyteArray JNICALL +Java_io_ray_streaming_runtime_transfer_DataReader_getOffsetsInfoNative(JNIEnv *env, + jobject thisObj, + jlong ptr) { + auto reader = reinterpret_cast(ptr); + std::unordered_map *offset_map = nullptr; + reader->GetOffsetInfo(offset_map); + STREAMING_CHECK(offset_map); + // queue nums + (queue id + seq id + message id) * queue nums + int offset_data_size = + sizeof(uint32_t) + (kUniqueIDSize + sizeof(uint64_t) * 2) * offset_map->size(); + jbyteArray offsets_info = env->NewByteArray(offset_data_size); + int offset = 0; + // total queue nums + auto queue_nums = static_cast(offset_map->size()); + env->SetByteArrayRegion(offsets_info, offset, sizeof(uint32_t), + reinterpret_cast(&queue_nums)); + offset += sizeof(uint32_t); + // queue name & offset + for (auto &p : *offset_map) { + env->SetByteArrayRegion(offsets_info, offset, kUniqueIDSize, + reinterpret_cast(p.first.Data())); + offset += kUniqueIDSize; + // msg_id + env->SetByteArrayRegion(offsets_info, offset, sizeof(uint64_t), + reinterpret_cast(&p.second.current_message_id)); + offset += sizeof(uint64_t); + } + return offsets_info; +} \ No newline at end of file diff --git 
a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h new file mode 100644 index 00000000..43f677d3 --- /dev/null +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h @@ -0,0 +1,72 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class io_ray_streaming_runtime_transfer_DataReader */ + +#ifndef _Included_io_ray_streaming_runtime_transfer_DataReader +#define _Included_io_ray_streaming_runtime_transfer_DataReader +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: io_ray_streaming_runtime_transfer_DataReader + * Method: createDataReaderNative + * Signature: + * (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[JJLjava/util/List;[BZ)J + */ +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( + JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlong, jobject, jbyteArray, + jboolean); + +/* + * Class: io_ray_streaming_runtime_transfer_DataReader + * Method: getBundleNative + * Signature: (JJJJ)V + */ +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBundleNative( + JNIEnv *, jobject, jlong, jlong, jlong, jlong); + +/* + * Class: io_ray_streaming_runtime_transfer_DataReader + * Method: getOffsetsInfoNative + * Signature: (J)[B + */ 
+JNIEXPORT jbyteArray JNICALL +Java_io_ray_streaming_runtime_transfer_DataReader_getOffsetsInfoNative(JNIEnv *, jobject, + jlong); + +/* + * Class: io_ray_streaming_runtime_transfer_DataReader + * Method: stopReaderNative + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_stopReaderNative( + JNIEnv *, jobject, jlong); + +/* + * Class: io_ray_streaming_runtime_transfer_DataReader + * Method: closeReaderNative + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *, jobject, + jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc new file mode 100644 index 00000000..efccdf98 --- /dev/null +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc @@ -0,0 +1,118 @@ +#include "io_ray_streaming_runtime_transfer_DataWriter.h" + +#include "config/streaming_config.h" +#include "data_writer.h" +#include "streaming_jni_common.h" + +using namespace ray::streaming; + +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative( + JNIEnv *env, jclass, jobject initial_parameters, jobjectArray output_queue_ids, + jlongArray msg_ids, jlong channel_size, jbyteArray conf_bytes_array, + jboolean is_mock) { + STREAMING_LOG(INFO) << "[JNI]: createDataWriterNative."; + + std::vector parameter_vec; + ParseChannelInitParameters(env, initial_parameters, parameter_vec); + std::vector queue_id_vec = + jarray_to_object_id_vec(env, output_queue_ids); + for (auto id : queue_id_vec) { + STREAMING_LOG(INFO) << "output channel id: " << id.Hex(); + } + STREAMING_LOG(INFO) << "total channel size: " << channel_size << "*" + << queue_id_vec.size() << "=" << queue_id_vec.size() * channel_size; + LongVectorFromJLongArray long_array_obj(env, msg_ids); + std::vector msg_ids_vec = 
LongVectorFromJLongArray(env, msg_ids).data; + std::vector queue_size_vec(long_array_obj.data.size(), channel_size); + std::vector remain_id_vec; + + RawDataFromJByteArray conf(env, conf_bytes_array); + STREAMING_CHECK(conf.data != nullptr); + auto runtime_context = std::make_shared(); + if (conf.data_size > 0) { + runtime_context->SetConfig(conf.data, conf.data_size); + } + if (is_mock) { + runtime_context->MarkMockTest(); + } + auto *data_writer = new DataWriter(runtime_context); + auto status = + data_writer->Init(queue_id_vec, parameter_vec, msg_ids_vec, queue_size_vec); + if (status != StreamingStatus::OK) { + STREAMING_LOG(WARNING) << "DataWriter init failed."; + } else { + STREAMING_LOG(INFO) << "DataWriter init success"; + } + + data_writer->Run(); + return reinterpret_cast(data_writer); +} + +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_writeMessageNative( + JNIEnv *env, jobject, jlong writer_ptr, jlong qid_ptr, jlong address, jint size) { + auto *data_writer = reinterpret_cast(writer_ptr); + auto qid = *reinterpret_cast(qid_ptr); + auto data = reinterpret_cast(address); + auto data_size = static_cast(size); + jlong result = data_writer->WriteMessageToBufferRing(qid, data, data_size, + StreamingMessageType::Message); + + if (result == 0) { + STREAMING_LOG(INFO) << "writer interrupted, return 0."; + throwChannelInterruptException(env, "writer interrupted."); + } + return result; +} + +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative( + JNIEnv *env, jobject thisObj, jlong ptr) { + STREAMING_LOG(INFO) << "jni: stop writer."; + auto *data_writer = reinterpret_cast(ptr); + data_writer->Stop(); +} + +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *env, + jobject thisObj, + jlong ptr) { + auto *data_writer = reinterpret_cast(ptr); + delete data_writer; +} + +JNIEXPORT jlongArray JNICALL 
+Java_io_ray_streaming_runtime_transfer_DataWriter_getOutputMsgIdNative(JNIEnv *env, + jobject thisObj, + jlong ptr) { + DataWriter *writer_client = reinterpret_cast(ptr); + + std::vector result; + writer_client->GetChannelOffset(result); + + jlongArray jArray = env->NewLongArray(result.size()); + jlong jdata[result.size()]; + for (size_t i = 0; i < result.size(); ++i) { + *(jdata + i) = result[i]; + } + env->SetLongArrayRegion(jArray, 0, result.size(), jdata); + return jArray; +} + +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_broadcastBarrierNative( + JNIEnv *env, jobject thisObj, jlong ptr, jlong checkpointId, jbyteArray data) { + STREAMING_LOG(INFO) << "jni: broadcast barrier, cp_id=" << checkpointId; + RawDataFromJByteArray raw_data(env, data); + DataWriter *writer_client = reinterpret_cast(ptr); + writer_client->BroadcastBarrier(checkpointId, raw_data.data, raw_data.data_size); +} + +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_clearCheckpointNative( + JNIEnv *env, jobject thisObj, jlong ptr, jlong checkpointId) { + STREAMING_LOG(INFO) << "[Producer] jni: clearCheckpoints."; + auto *writer = reinterpret_cast(ptr); + writer->ClearCheckpoint(checkpointId); + STREAMING_LOG(INFO) << "[Producer] clear checkpoint done."; +} diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h new file mode 100644 index 00000000..ff6ebb83 --- /dev/null +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h @@ -0,0 +1,93 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class io_ray_streaming_runtime_transfer_DataWriter */ + +#ifndef _Included_io_ray_streaming_runtime_transfer_DataWriter +#define _Included_io_ray_streaming_runtime_transfer_DataWriter +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: io_ray_streaming_runtime_transfer_DataWriter + * Method: createWriterNative + * Signature: + * (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[JJ[BZ)J + */ +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_createWriterNative( + JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlong, jbyteArray, jboolean); + +/* + * Class: io_ray_streaming_runtime_transfer_DataWriter + * Method: writeMessageNative + * Signature: (JJJI)J + */ +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_writeMessageNative(JNIEnv *, jobject, + jlong, jlong, jlong, + jint); + +/* + * Class: io_ray_streaming_runtime_transfer_DataWriter + * Method: stopWriterNative + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_stopWriterNative( + JNIEnv *, jobject, jlong); + +/* + * Class: io_ray_streaming_runtime_transfer_DataWriter + * Method: closeWriterNative + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *, jobject, + jlong); + +/* + * Class: io_ray_streaming_runtime_transfer_DataWriter + * Method: getOutputMsgIdNative + * Signature: (J)[J + */ +JNIEXPORT jlongArray 
JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_getOutputMsgIdNative(JNIEnv *, jobject, + jlong); + +/* + * Class: io_ray_streaming_runtime_transfer_DataWriter + * Method: broadcastBarrierNative + * Signature: (JJJ[B)V + */ +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_broadcastBarrierNative(JNIEnv *, + jobject, jlong, + jlong, + jbyteArray); + +/* + * Class: io_ray_streaming_runtime_transfer_DataWriter + * Method: clearCheckpointNative + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_clearCheckpointNative(JNIEnv *, jobject, + jlong, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc new file mode 100644 index 00000000..39c8f539 --- /dev/null +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc @@ -0,0 +1,66 @@ +#include "io_ray_streaming_runtime_transfer_TransferHandler.h" + +#include "queue/queue_client.h" +#include "streaming_jni_common.h" + +using namespace ray::streaming; + +static std::shared_ptr JByteArrayToBuffer(JNIEnv *env, + jbyteArray bytes) { + RawDataFromJByteArray buf(env, bytes); + STREAMING_CHECK(buf.data != nullptr); + + return std::make_shared(buf.data, buf.data_size, true); +} + +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative( + JNIEnv *env, jobject this_obj) { + auto *writer_client = new WriterClient(); + return reinterpret_cast(writer_client); +} + +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative( + JNIEnv *env, jobject this_obj) { + auto *reader_client = new ReaderClient(); + return reinterpret_cast(reader_client); +} + +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative( + JNIEnv *env, jobject 
this_obj, jlong ptr, jbyteArray bytes) { + auto *writer_client = reinterpret_cast(ptr); + writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes)); +} + +JNIEXPORT jbyteArray JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative( + JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) { + auto *writer_client = reinterpret_cast(ptr); + std::shared_ptr result_buffer = + writer_client->OnWriterMessageSync(JByteArrayToBuffer(env, bytes)); + jbyteArray arr = env->NewByteArray(result_buffer->Size()); + env->SetByteArrayRegion(arr, 0, result_buffer->Size(), + reinterpret_cast(result_buffer->Data())); + return arr; +} + +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative( + JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) { + auto *reader_client = reinterpret_cast(ptr); + reader_client->OnReaderMessage(JByteArrayToBuffer(env, bytes)); +} + +JNIEXPORT jbyteArray JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative( + JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) { + auto *reader_client = reinterpret_cast(ptr); + auto result_buffer = reader_client->OnReaderMessageSync(JByteArrayToBuffer(env, bytes)); + + jbyteArray arr = env->NewByteArray(result_buffer->Size()); + env->SetByteArrayRegion(arr, 0, result_buffer->Size(), + reinterpret_cast(result_buffer->Data())); + return arr; +} diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h new file mode 100644 index 00000000..4e5c826f --- /dev/null +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h @@ -0,0 +1,81 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class io_ray_streaming_runtime_transfer_TransferHandler */ + +#ifndef _Included_io_ray_streaming_runtime_transfer_TransferHandler +#define _Included_io_ray_streaming_runtime_transfer_TransferHandler +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: io_ray_streaming_runtime_transfer_TransferHandler + * Method: createWriterClientNative + * Signature: ()J + */ +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(JNIEnv *, + jobject); + +/* + * Class: io_ray_streaming_runtime_transfer_TransferHandler + * Method: createReaderClientNative + * Signature: ()J + */ +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(JNIEnv *, + jobject); + +/* + * Class: io_ray_streaming_runtime_transfer_TransferHandler + * Method: handleWriterMessageNative + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative( + JNIEnv *, jobject, jlong, jbyteArray); + +/* + * Class: io_ray_streaming_runtime_transfer_TransferHandler + * Method: handleWriterMessageSyncNative + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative( + JNIEnv *, jobject, jlong, jbyteArray); + +/* + * Class: io_ray_streaming_runtime_transfer_TransferHandler + * Method: handleReaderMessageNative + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL 
+Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative( + JNIEnv *, jobject, jlong, jbyteArray); + +/* + * Class: io_ray_streaming_runtime_transfer_TransferHandler + * Method: handleReaderMessageSyncNative + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative( + JNIEnv *, jobject, jlong, jbyteArray); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.cc new file mode 100644 index 00000000..241db2b4 --- /dev/null +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.cc @@ -0,0 +1,20 @@ +#include "io_ray_streaming_runtime_transfer_channel_ChannelId.h" +#include "streaming_jni_common.h" + +using namespace ray::streaming; + +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_channel_ChannelId_createNativeId( + JNIEnv *env, jclass cls, jlong qid_address) { + auto id = ray::ObjectID::FromBinary( + std::string(reinterpret_cast(qid_address), ray::ObjectID::Size())); + return reinterpret_cast(new ray::ObjectID(id)); +} + +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_channel_ChannelId_destroyNativeId( + JNIEnv *env, jclass cls, jlong native_id_ptr) { + auto id = reinterpret_cast(native_id_ptr); + STREAMING_CHECK(id != nullptr); + delete id; +} diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.h new file mode 100644 index 00000000..ab1295af --- /dev/null +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.h @@ -0,0 +1,47 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class io_ray_streaming_runtime_transfer_channel_ChannelId */ + +#ifndef _Included_io_ray_streaming_runtime_transfer_channel_ChannelId +#define _Included_io_ray_streaming_runtime_transfer_channel_ChannelId +#ifdef __cplusplus +extern "C" { +#endif +#undef io_ray_streaming_runtime_transfer_channel_ChannelId_ID_LENGTH +#define io_ray_streaming_runtime_transfer_channel_ChannelId_ID_LENGTH 20L +/* + * Class: io_ray_streaming_runtime_transfer_channel_ChannelId + * Method: createNativeId + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_channel_ChannelId_createNativeId(JNIEnv *, jclass, + jlong); + +/* + * Class: io_ray_streaming_runtime_transfer_channel_ChannelId + * Method: destroyNativeId + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_channel_ChannelId_destroyNativeId(JNIEnv *, jclass, + jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/streaming/src/lib/java/streaming_jni_common.cc b/streaming/src/lib/java/streaming_jni_common.cc new file mode 100644 index 00000000..226c5502 --- /dev/null +++ b/streaming/src/lib/java/streaming_jni_common.cc @@ -0,0 +1,192 @@ +#include "streaming_jni_common.h" + +std::vector jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr) { + int stringCount = env->GetArrayLength(jarr); + std::vector object_id_vec; + for (int i = 0; i < stringCount; i++) { + auto jstr = (jbyteArray)(env->GetObjectArrayElement(jarr, i)); + UniqueIdFromJByteArray 
idFromJByteArray(env, jstr); + object_id_vec.push_back(idFromJByteArray.PID); + } + return object_id_vec; +} + +std::vector jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr) { + int count = env->GetArrayLength(jarr); + std::vector actor_id_vec; + for (int i = 0; i < count; i++) { + auto bytes = (jbyteArray)(env->GetObjectArrayElement(jarr, i)); + std::string id_str(ray::ActorID::Size(), 0); + env->GetByteArrayRegion(bytes, 0, ray::ActorID::Size(), + reinterpret_cast(&id_str.front())); + actor_id_vec.push_back(ActorID::FromBinary(id_str)); + } + + return actor_id_vec; +} + +jint throwRuntimeException(JNIEnv *env, const char *message) { + jclass exClass; + char className[] = "java/lang/RuntimeException"; + exClass = env->FindClass(className); + return env->ThrowNew(exClass, message); +} + +jint throwChannelInitException(JNIEnv *env, const char *message, + const std::vector &abnormal_queues) { + jclass array_list_class = env->FindClass("java/util/ArrayList"); + jmethodID array_list_constructor = env->GetMethodID(array_list_class, "", "()V"); + jmethodID array_list_add = + env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z"); + jobject array_list = env->NewObject(array_list_class, array_list_constructor); + + for (auto &q_id : abnormal_queues) { + jbyteArray jbyte_array = env->NewByteArray(kUniqueIDSize); + env->SetByteArrayRegion( + jbyte_array, 0, kUniqueIDSize, + const_cast(reinterpret_cast(q_id.Data()))); + env->CallBooleanMethod(array_list, array_list_add, jbyte_array); + } + + jclass ex_class = + env->FindClass("io/ray/streaming/runtime/transfer/ChannelInitException"); + jmethodID ex_constructor = + env->GetMethodID(ex_class, "", "(Ljava/lang/String;Ljava/util/List;)V"); + jstring message_jstr = env->NewStringUTF(message); + jobject ex_obj = env->NewObject(ex_class, ex_constructor, message_jstr, array_list); + env->DeleteLocalRef(message_jstr); + return env->Throw((jthrowable)ex_obj); +} + +jint throwChannelInterruptException(JNIEnv *env, const 
char *message) { + jclass ex_class = + env->FindClass("io/ray/streaming/runtime/transfer/ChannelInterruptException"); + return env->ThrowNew(ex_class, message); +} + +jclass LoadClass(JNIEnv *env, const char *class_name) { + jclass tempLocalClassRef = env->FindClass(class_name); + jclass ret = (jclass)env->NewGlobalRef(tempLocalClassRef); + STREAMING_CHECK(ret) << "Can't load Java class " << class_name; + env->DeleteLocalRef(tempLocalClassRef); + return ret; +} + +template +void JavaListToNativeVector(JNIEnv *env, jobject java_list, + std::vector *native_vector, + std::function element_converter) { + jclass java_list_class = LoadClass(env, "java/util/List"); + jmethodID java_list_size = env->GetMethodID(java_list_class, "size", "()I"); + jmethodID java_list_get = + env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;"); + int size = env->CallIntMethod(java_list, java_list_size); + native_vector->clear(); + for (int i = 0; i < size; i++) { + native_vector->emplace_back( + element_converter(env, env->CallObjectMethod(java_list, java_list_get, (jint)i))); + } +} + +/// Convert a Java byte array to a C++ UniqueID. +template +inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) { + std::string id_str(ID::Size(), 0); + env->GetByteArrayRegion(bytes, 0, ID::Size(), + reinterpret_cast(&id_str.front())); + return ID::FromBinary(id_str); +} + +/// Convert a Java String to C++ std::string. +std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) { + const char *c_str = env->GetStringUTFChars(jstr, nullptr); + std::string result(c_str); + env->ReleaseStringUTFChars(static_cast(jstr), c_str); + return result; +} + +/// Convert a Java List to C++ std::vector. 
+void JavaStringListToNativeStringVector(JNIEnv *env, jobject java_list, + std::vector *native_vector) { + JavaListToNativeVector( + env, java_list, native_vector, [](JNIEnv *env, jobject jstr) { + return JavaStringToNativeString(env, static_cast(jstr)); + }); +} + +std::shared_ptr FunctionDescriptorToRayFunction(JNIEnv *env, + jobject functionDescriptor) { + jclass java_language_class = LoadClass(env, "io/ray/runtime/generated/Common$Language"); + jclass java_function_descriptor_class = + LoadClass(env, "io/ray/runtime/functionmanager/FunctionDescriptor"); + jmethodID java_language_get_number = + env->GetMethodID(java_language_class, "getNumber", "()I"); + jmethodID java_function_descriptor_get_language = + env->GetMethodID(java_function_descriptor_class, "getLanguage", + "()Lio/ray/runtime/generated/Common$Language;"); + jobject java_language = + env->CallObjectMethod(functionDescriptor, java_function_descriptor_get_language); + auto language = static_cast<::Language>( + env->CallIntMethod(java_language, java_language_get_number)); + std::vector function_descriptor_list; + jmethodID java_function_descriptor_to_list = + env->GetMethodID(java_function_descriptor_class, "toList", "()Ljava/util/List;"); + JavaStringListToNativeStringVector( + env, env->CallObjectMethod(functionDescriptor, java_function_descriptor_to_list), + &function_descriptor_list); + ray::FunctionDescriptor function_descriptor = + ray::FunctionDescriptorBuilder::FromVector(language, function_descriptor_list); + RayFunction ray_function(language, function_descriptor); + return std::make_shared(ray_function); +} + +void ParseChannelInitParameters( + JNIEnv *env, jobject param_obj, + std::vector ¶meter_vec) { + jclass java_streaming_queue_initial_parameters_class = + LoadClass(env, + "io/ray/streaming/runtime/transfer/" + "ChannelCreationParametersBuilder"); + jmethodID java_streaming_queue_initial_parameters_getParameters_method = + env->GetMethodID(java_streaming_queue_initial_parameters_class, 
"getParameters", + "()Ljava/util/List;"); + STREAMING_CHECK(java_streaming_queue_initial_parameters_getParameters_method != + nullptr); + jclass java_streaming_queue_initial_parameters_parameter_class = + LoadClass(env, + "io/ray/streaming/runtime/transfer/" + "ChannelCreationParametersBuilder$Parameter"); + jmethodID java_getActorIdBytes_method = env->GetMethodID( + java_streaming_queue_initial_parameters_parameter_class, "getActorIdBytes", "()[B"); + jmethodID java_getAsyncFunctionDescriptor_method = + env->GetMethodID(java_streaming_queue_initial_parameters_parameter_class, + "getAsyncFunctionDescriptor", + "()Lio/ray/runtime/functionmanager/FunctionDescriptor;"); + jmethodID java_getSyncFunctionDescriptor_method = + env->GetMethodID(java_streaming_queue_initial_parameters_parameter_class, + "getSyncFunctionDescriptor", + "()Lio/ray/runtime/functionmanager/FunctionDescriptor;"); + // Call getParameters method + jobject parameter_list = env->CallObjectMethod( + param_obj, java_streaming_queue_initial_parameters_getParameters_method); + + JavaListToNativeVector( + env, parameter_list, ¶meter_vec, + [java_getActorIdBytes_method, java_getAsyncFunctionDescriptor_method, + java_getSyncFunctionDescriptor_method](JNIEnv *env, jobject jobject_parameter) { + ray::streaming::ChannelCreationParameter native_parameter; + jbyteArray jobject_actor_id_bytes = (jbyteArray)env->CallObjectMethod( + jobject_parameter, java_getActorIdBytes_method); + native_parameter.actor_id = + JavaByteArrayToId(env, jobject_actor_id_bytes); + jobject jobject_async_func = env->CallObjectMethod( + jobject_parameter, java_getAsyncFunctionDescriptor_method); + native_parameter.async_function = + FunctionDescriptorToRayFunction(env, jobject_async_func); + jobject jobject_sync_func = env->CallObjectMethod( + jobject_parameter, java_getSyncFunctionDescriptor_method); + native_parameter.sync_function = + FunctionDescriptorToRayFunction(env, jobject_sync_func); + return native_parameter; + }); +} diff 
--git a/streaming/src/lib/java/streaming_jni_common.h b/streaming/src/lib/java/streaming_jni_common.h new file mode 100644 index 00000000..dfd6e47d --- /dev/null +++ b/streaming/src/lib/java/streaming_jni_common.h @@ -0,0 +1,109 @@ +#pragma once + +#include + +#include + +#include "channel/channel.h" +#include "ray/core_worker/common.h" +#include "util/streaming_logging.h" + +using namespace ray::core; + +class UniqueIdFromJByteArray { + private: + JNIEnv *_env; + jbyteArray _bytes; + jbyte *b; + + public: + ray::ObjectID PID; + + UniqueIdFromJByteArray(JNIEnv *env, jbyteArray wid) { + _env = env; + _bytes = wid; + + b = reinterpret_cast(_env->GetByteArrayElements(_bytes, nullptr)); + PID = ray::ObjectID::FromBinary( + std::string(reinterpret_cast(b), ray::ObjectID::Size())); + } + + ~UniqueIdFromJByteArray() { _env->ReleaseByteArrayElements(_bytes, b, 0); } +}; + +class RawDataFromJByteArray { + private: + JNIEnv *_env; + jbyteArray _bytes; + + public: + uint8_t *data; + uint32_t data_size; + + RawDataFromJByteArray(JNIEnv *env, jbyteArray bytes) { + _env = env; + _bytes = bytes; + data_size = _env->GetArrayLength(_bytes); + jbyte *b = reinterpret_cast(_env->GetByteArrayElements(_bytes, nullptr)); + data = reinterpret_cast(b); + } + + ~RawDataFromJByteArray() { + _env->ReleaseByteArrayElements(_bytes, reinterpret_cast(data), 0); + } +}; + +class StringFromJString { + private: + JNIEnv *_env; + const char *j_str; + jstring jni_str; + + public: + std::string str; + + StringFromJString(JNIEnv *env, jstring jni_str_) { + jni_str = jni_str_; + _env = env; + j_str = env->GetStringUTFChars(jni_str, nullptr); + str = std::string(j_str); + } + + ~StringFromJString() { _env->ReleaseStringUTFChars(jni_str, j_str); } +}; + +class LongVectorFromJLongArray { + private: + JNIEnv *_env; + jlongArray long_array; + jlong *long_array_ptr = nullptr; + + public: + std::vector data; + + LongVectorFromJLongArray(JNIEnv *env, jlongArray long_array_) { + _env = env; + long_array = 
long_array_; + + long_array_ptr = env->GetLongArrayElements(long_array, nullptr); + jsize seq_id_size = env->GetArrayLength(long_array); + data = std::vector(long_array_ptr, long_array_ptr + seq_id_size); + } + + ~LongVectorFromJLongArray() { + _env->ReleaseLongArrayElements(long_array, long_array_ptr, 0); + } +}; + +std::vector jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr); +std::vector jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr); + +jint throwRuntimeException(JNIEnv *env, const char *message); +jint throwChannelInitException(JNIEnv *env, const char *message, + const std::vector &abnormal_queues); +jint throwChannelInterruptException(JNIEnv *env, const char *message); +std::shared_ptr FunctionDescriptorToRayFunction(JNIEnv *env, + jobject functionDescriptor); +void ParseChannelInitParameters( + JNIEnv *env, jobject param_obj, + std::vector ¶meter_vec); diff --git a/streaming/src/message/message.cc b/streaming/src/message/message.cc new file mode 100644 index 00000000..3216e3b3 --- /dev/null +++ b/streaming/src/message/message.cc @@ -0,0 +1,101 @@ +#include "message/message.h" + +#include +#include +#include + +#include "ray/common/status.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +StreamingMessage::StreamingMessage(std::shared_ptr &payload_data, + uint32_t payload_size, uint64_t msg_id, + StreamingMessageType message_type) + : payload_(payload_data), + payload_size_(payload_size), + message_type_(message_type), + message_id_(msg_id) {} + +StreamingMessage::StreamingMessage(std::shared_ptr &&payload_data, + uint32_t payload_size, uint64_t msg_id, + StreamingMessageType message_type) + : payload_(payload_data), + payload_size_(payload_size), + message_type_(message_type), + message_id_(msg_id) {} + +StreamingMessage::StreamingMessage(const uint8_t *payload_data, uint32_t payload_size, + uint64_t msg_id, StreamingMessageType message_type) + : payload_size_(payload_size), message_type_(message_type), 
message_id_(msg_id) { + payload_.reset(new uint8_t[payload_size], std::default_delete()); + std::memcpy(payload_.get(), payload_data, payload_size); +} + +StreamingMessage::StreamingMessage(const StreamingMessage &msg) { + payload_size_ = msg.payload_size_; + payload_ = msg.payload_; + message_id_ = msg.message_id_; + message_type_ = msg.message_type_; +} + +StreamingMessagePtr StreamingMessage::FromBytes(const uint8_t *bytes, + bool verifer_check) { + uint32_t byte_offset = 0; + uint32_t data_size = *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(data_size); + + uint64_t msg_id = *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(msg_id); + + StreamingMessageType msg_type = + *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(msg_type); + + auto buf = new uint8_t[data_size]; + std::memcpy(buf, bytes + byte_offset, data_size); + auto data_ptr = std::shared_ptr(buf, std::default_delete()); + return std::make_shared(data_ptr, data_size, msg_id, msg_type); +} + +void StreamingMessage::ToBytes(uint8_t *serlizable_data) { + uint32_t byte_offset = 0; + std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&payload_size_), + sizeof(payload_size_)); + byte_offset += sizeof(payload_size_); + + std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&message_id_), + sizeof(message_id_)); + byte_offset += sizeof(message_id_); + + std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&message_type_), + sizeof(message_type_)); + byte_offset += sizeof(message_type_); + + std::memcpy(serlizable_data + byte_offset, reinterpret_cast(payload_.get()), + payload_size_); + + byte_offset += payload_size_; + + STREAMING_CHECK(byte_offset == this->ClassBytesSize()); +} + +bool StreamingMessage::operator==(const StreamingMessage &message) const { + return PayloadSize() == message.PayloadSize() && + GetMessageId() == message.GetMessageId() && + GetMessageType() == message.GetMessageType() && + !std::memcmp(Payload(), 
message.Payload(), PayloadSize()); +} + +std::ostream &operator<<(std::ostream &os, const StreamingMessage &message) { + os << "{" + << " message_type_: " << static_cast(message.GetMessageType()) + << " message_id_: " << message.GetMessageId() + << " payload_size_: " << message.payload_size_ + << " payload_: " << (void *)message.payload_.get() << "}"; + return os; +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/message/message.h b/streaming/src/message/message.h new file mode 100644 index 00000000..3052e9aa --- /dev/null +++ b/streaming/src/message/message.h @@ -0,0 +1,143 @@ +#pragma once + +#include +#include + +namespace ray { +namespace streaming { + +class StreamingMessage; + +typedef std::shared_ptr StreamingMessagePtr; + +enum class StreamingMessageType : uint32_t { + Barrier = 1, + Message = 2, + MIN = Barrier, + MAX = Message +}; + +enum class StreamingBarrierType : uint32_t { GlobalBarrier = 0 }; + +struct StreamingBarrierHeader { + StreamingBarrierType barrier_type; + uint64_t barrier_id; + StreamingBarrierHeader() = default; + StreamingBarrierHeader(StreamingBarrierType barrier_type, uint64_t barrier_id) { + this->barrier_type = barrier_type; + this->barrier_id = barrier_id; + } + inline bool IsGlobalBarrier() { + return StreamingBarrierType::GlobalBarrier == barrier_type; + } +}; + +constexpr uint32_t kMessageHeaderSize = + sizeof(uint32_t) + sizeof(uint64_t) + sizeof(StreamingMessageType); + +constexpr uint32_t kBarrierHeaderSize = sizeof(StreamingBarrierType) + sizeof(uint64_t); + +/// All messages should be wrapped by this protocol. +// DataSize means length of raw data, message id is increasing from [1, +INF]. +// MessageType will be used for barrier transporting and checkpoint. 
+/// +----------------+ +/// | PayloadSize=U32| +/// +----------------+ +/// | MessageId=U64 | +/// +----------------+ +/// | MessageType=U32| +/// +----------------+ +/// | Payload=var | +/// +----------------+ +/// Payload field contains barrier header and carried buffer if message type is +/// global/partial barrier. +/// +/// Barrier's Payload field: +/// +----------------------------+ +/// | StreamingBarrierType=U32 | +/// +----------------------------+ +/// | barrier_id=U64 | +/// +----------------------------+ +/// | carried_buffer=var | +/// +----------------------------+ + +class StreamingMessage { + private: + std::shared_ptr payload_; + uint32_t payload_size_; + StreamingMessageType message_type_; + uint64_t message_id_; + + public: + /// Copy raw data from outside shared buffer. + /// \param payload_ raw data from user buffer + /// \param payload_size_ raw data size + /// \param msg_id message id + /// \param message_type + StreamingMessage(std::shared_ptr &payload_data, uint32_t payload_size, + uint64_t msg_id, StreamingMessageType message_type); + + /// Move outsite raw data to message data. + /// \param payload_ raw data from user buffer + /// \param payload_size_ raw data size + /// \param msg_id message id + /// \param message_type + StreamingMessage(std::shared_ptr &&payload_data, uint32_t payload_size, + uint64_t msg_id, StreamingMessageType message_type); + + /// Copy raw data from outside buffer. 
+ /// \param payload_ raw data from user buffer + /// \param payload_size_ raw data size + /// \param msg_id message id + /// \param message_type + StreamingMessage(const uint8_t *payload_data, uint32_t payload_size, uint64_t msg_id, + StreamingMessageType message_type); + + StreamingMessage(const StreamingMessage &); + + StreamingMessage operator=(const StreamingMessage &) = delete; + + virtual ~StreamingMessage() = default; + + inline StreamingMessageType GetMessageType() const { return message_type_; } + inline uint64_t GetMessageId() const { return message_id_; } + + inline uint8_t *Payload() const { return payload_.get(); } + + inline uint32_t PayloadSize() const { return payload_size_; } + + inline bool IsMessage() { return StreamingMessageType::Message == message_type_; } + inline bool IsBarrier() { return StreamingMessageType::Barrier == message_type_; } + + bool operator==(const StreamingMessage &) const; + + static inline std::shared_ptr MakeBarrierPayload( + StreamingBarrierHeader &barrier_header, const uint8_t *data, uint32_t data_size) { + std::shared_ptr ptr(new uint8_t[data_size + kBarrierHeaderSize], + std::default_delete()); + std::memcpy(ptr.get(), &barrier_header.barrier_type, sizeof(StreamingBarrierType)); + std::memcpy(ptr.get() + sizeof(StreamingBarrierType), &barrier_header.barrier_id, + sizeof(uint64_t)); + if (data && data_size > 0) { + std::memcpy(ptr.get() + kBarrierHeaderSize, data, data_size); + } + return ptr; + } + + virtual void ToBytes(uint8_t *data); + static StreamingMessagePtr FromBytes(const uint8_t *data, bool verifer_check = true); + + inline virtual uint32_t ClassBytesSize() { return kMessageHeaderSize + payload_size_; } + + static inline void GetBarrierIdFromRawData(const uint8_t *data, + StreamingBarrierHeader *barrier_header) { + barrier_header->barrier_type = *reinterpret_cast(data); + barrier_header->barrier_id = + *reinterpret_cast(data + sizeof(StreamingBarrierType)); + } + + friend std::ostream 
&operator<<(std::ostream &os, const StreamingMessage &message); +}; + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/message/message_bundle.cc b/streaming/src/message/message_bundle.cc new file mode 100644 index 00000000..629a0613 --- /dev/null +++ b/streaming/src/message/message_bundle.cc @@ -0,0 +1,208 @@ +#include "message/message_bundle.h" + +#include +#include + +#include "config/streaming_config.h" +#include "ray/common/status.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { +StreamingMessageBundle::StreamingMessageBundle(uint64_t last_offset_seq_id, + uint64_t message_bundle_ts) + : StreamingMessageBundleMeta(message_bundle_ts, last_offset_seq_id, 0, + StreamingMessageBundleType::Empty) { + this->raw_bundle_size_ = 0; +} + +StreamingMessageBundleMeta::StreamingMessageBundleMeta(const uint8_t *bytes) { + std::memcpy(GetFirstMemberAddress(), bytes, + kMessageBundleMetaHeaderSize - sizeof(uint32_t)); +} + +StreamingMessageBundleMeta::StreamingMessageBundleMeta( + const uint64_t message_bundle_ts, const uint64_t last_offset_seq_id, + const uint32_t message_list_size, const StreamingMessageBundleType bundle_type) + : message_bundle_ts_(message_bundle_ts), + last_message_id_(last_offset_seq_id), + message_list_size_(message_list_size), + bundle_type_(bundle_type) { + STREAMING_CHECK(message_list_size <= StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE); +} + +void StreamingMessageBundleMeta::ToBytes(uint8_t *bytes) { + uint32_t magicNum = StreamingMessageBundleMeta::StreamingMessageBundleMagicNum; + std::memcpy(bytes, reinterpret_cast(&magicNum), sizeof(uint32_t)); + std::memcpy(bytes + sizeof(uint32_t), GetFirstMemberAddress(), + kMessageBundleMetaHeaderSize - sizeof(uint32_t)); +} + +StreamingMessageBundleMetaPtr StreamingMessageBundleMeta::FromBytes(const uint8_t *bytes, + bool check) { + STREAMING_CHECK(bytes); + + uint32_t byte_offset = 0; + STREAMING_CHECK(CheckBundleMagicNum(bytes)); + byte_offset += 
sizeof(uint32_t); + + auto result = std::make_shared(bytes + byte_offset); + STREAMING_CHECK(result->GetMessageListSize() <= + StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE); + return result; +} + +bool StreamingMessageBundleMeta::operator==(StreamingMessageBundleMeta &meta) const { + return this->message_list_size_ == meta.GetMessageListSize() && + this->message_bundle_ts_ == meta.GetMessageBundleTs() && + this->bundle_type_ == meta.GetBundleType() && + this->last_message_id_ == meta.GetLastMessageId(); +} + +bool StreamingMessageBundleMeta::operator==(StreamingMessageBundleMeta *meta) const { + return operator==(*meta); +} + +std::ostream &operator<<(std::ostream &os, const StreamingMessageBundleMeta &meta) { + os << "{" + << "last_message_id_: " << meta.last_message_id_ + << ", message_list_size_: " << meta.message_list_size_ + << ", bundle_type_: " << static_cast(meta.bundle_type_) << "}"; + return os; +} + +StreamingMessageBundleMeta::StreamingMessageBundleMeta() + : bundle_type_(StreamingMessageBundleType::Empty) {} + +StreamingMessageBundle::StreamingMessageBundle( + std::list &&message_list, uint64_t message_ts, + uint64_t last_offset_seq_id, StreamingMessageBundleType bundle_type, + uint32_t raw_data_size) + : StreamingMessageBundleMeta(message_ts, last_offset_seq_id, message_list.size(), + bundle_type), + raw_bundle_size_(raw_data_size), + message_list_(message_list) { + if (bundle_type_ != StreamingMessageBundleType::Empty) { + if (!raw_bundle_size_) { + raw_bundle_size_ = std::accumulate( + message_list_.begin(), message_list_.end(), 0, + [](uint32_t x, StreamingMessagePtr &y) { return x + y->ClassBytesSize(); }); + } + } +} + +StreamingMessageBundle::StreamingMessageBundle( + std::list &message_list, uint64_t message_ts, + uint64_t last_offset_seq_id, StreamingMessageBundleType bundle_type, + uint32_t raw_data_size) + : StreamingMessageBundle(std::list(message_list), message_ts, + last_offset_seq_id, bundle_type, raw_data_size) {} + 
+StreamingMessageBundle::StreamingMessageBundle(StreamingMessageBundle &bundle) { + message_bundle_ts_ = bundle.message_bundle_ts_; + message_list_size_ = bundle.message_list_size_; + raw_bundle_size_ = bundle.raw_bundle_size_; + bundle_type_ = bundle.bundle_type_; + last_message_id_ = bundle.last_message_id_; + message_list_ = bundle.message_list_; +} + +void StreamingMessageBundle::ToBytes(uint8_t *bytes) { + uint32_t byte_offset = 0; + StreamingMessageBundleMeta::ToBytes(bytes + byte_offset); + + byte_offset += StreamingMessageBundleMeta::ClassBytesSize(); + + std::memcpy(bytes + byte_offset, reinterpret_cast(&raw_bundle_size_), + sizeof(uint32_t)); + byte_offset += sizeof(uint32_t); + + if (raw_bundle_size_ > 0) { + ConvertMessageListToRawData(message_list_, raw_bundle_size_, bytes + byte_offset); + } +} + +StreamingMessageBundlePtr StreamingMessageBundle::FromBytes(const uint8_t *bytes, + bool verifer_check) { + uint32_t byte_offset = 0; + StreamingMessageBundleMetaPtr meta_ptr = + StreamingMessageBundleMeta::FromBytes(bytes + byte_offset); + byte_offset += meta_ptr->ClassBytesSize(); + + uint32_t raw_data_size = *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(uint32_t); + + std::list message_list; + // only message bundle own raw data + if (meta_ptr->GetBundleType() != StreamingMessageBundleType::Empty) { + GetMessageListFromRawData(bytes + byte_offset, raw_data_size, + meta_ptr->GetMessageListSize(), message_list); + byte_offset += raw_data_size; + } + auto result = std::make_shared( + message_list, meta_ptr->GetMessageBundleTs(), meta_ptr->GetLastMessageId(), + meta_ptr->GetBundleType()); + STREAMING_CHECK(byte_offset == result->ClassBytesSize()); + return result; +} + +void StreamingMessageBundle::GetMessageListFromRawData( + const uint8_t *bytes, uint32_t byte_size, uint32_t message_list_size, + std::list &message_list) { + uint32_t byte_offset = 0; + // only message bundle own raw data + for (size_t i = 0; i < message_list_size; ++i) { + 
StreamingMessagePtr item = StreamingMessage::FromBytes(bytes + byte_offset); + message_list.push_back(item); + byte_offset += item->ClassBytesSize(); + } + STREAMING_CHECK(byte_offset == byte_size); +} + +void StreamingMessageBundle::GetMessageList( + std::list &message_list) { + message_list = message_list_; +} + +void StreamingMessageBundle::ConvertMessageListToRawData( + const std::list &message_list, uint32_t raw_data_size, + uint8_t *raw_data) { + uint32_t byte_offset = 0; + for (auto &message : message_list) { + message->ToBytes(raw_data + byte_offset); + byte_offset += message->ClassBytesSize(); + } + STREAMING_CHECK(byte_offset == raw_data_size); +} + +bool StreamingMessageBundle::operator==(StreamingMessageBundle &bundle) const { + if (!(StreamingMessageBundleMeta::operator==(&bundle) && + this->GetRawBundleSize() == bundle.GetRawBundleSize() && + this->GetMessageListSize() == bundle.GetMessageListSize())) { + return false; + } + auto it1 = message_list_.begin(); + auto it2 = bundle.message_list_.begin(); + while (it1 != message_list_.end() && it2 != bundle.message_list_.end()) { + if (!((*it1).get()->operator==(*(*it2).get()))) { + return false; + } + it1++; + it2++; + } + return true; +} + +bool StreamingMessageBundle::operator==(StreamingMessageBundle *bundle) const { + return this->operator==(*bundle); +} + +std::ostream &operator<<(std::ostream &os, const DataBundle &bundle) { + os << "{" + << "data: " << (void *)bundle.data << ", data_size: " << bundle.data_size + << ", channel last_barrier_id: " << bundle.last_barrier_id + << ", meta: " << *(bundle.meta) << "}"; + return os; +} +} // namespace streaming +} // namespace ray diff --git a/streaming/src/message/message_bundle.h b/streaming/src/message/message_bundle.h new file mode 100644 index 00000000..2cad05e3 --- /dev/null +++ b/streaming/src/message/message_bundle.h @@ -0,0 +1,211 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "message/message.h" +#include 
"ray/common/id.h" + +namespace ray { +namespace streaming { + +enum class StreamingMessageBundleType : uint32_t { + Empty = 1, + Barrier = 2, + Bundle = 3, + MIN = Empty, + MAX = Bundle +}; + +class StreamingMessageBundleMeta; +class StreamingMessageBundle; + +typedef std::shared_ptr StreamingMessageBundlePtr; +typedef std::shared_ptr StreamingMessageBundleMetaPtr; + +constexpr uint32_t kMessageBundleMetaHeaderSize = sizeof(uint32_t) + sizeof(uint32_t) + + sizeof(uint64_t) + sizeof(uint64_t) + + sizeof(StreamingMessageBundleType); + +constexpr uint32_t kMessageBundleHeaderSize = + kMessageBundleMetaHeaderSize + sizeof(uint32_t); + +class StreamingMessageBundleMeta { + public: + static const uint32_t StreamingMessageBundleMagicNum = 0xCAFEBABA; + + protected: + uint64_t message_bundle_ts_; + + uint64_t last_message_id_; + + uint32_t message_list_size_; + + StreamingMessageBundleType bundle_type_; + + private: + /// To speed up memory copy and serilization, we use memory layout of compiler related + /// member variables. It's must be modified if any field is going to be inserted before + /// first member property. + /// Reference + /// :/http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2018/p1113r0.html#2254). 
+ inline uint8_t *GetFirstMemberAddress() { + return reinterpret_cast(&message_bundle_ts_); + } + + public: + explicit StreamingMessageBundleMeta(const uint8_t *bytes); + + explicit StreamingMessageBundleMeta(const uint64_t message_bunddle_tes, + const uint64_t last_offset_seq_id, + const uint32_t message_list_size, + const StreamingMessageBundleType bundle_type); + + explicit StreamingMessageBundleMeta(StreamingMessageBundleMeta *); + + explicit StreamingMessageBundleMeta(); + + virtual ~StreamingMessageBundleMeta() = default; + + bool operator==(StreamingMessageBundleMeta &) const; + + bool operator==(StreamingMessageBundleMeta *) const; + + inline uint64_t GetMessageBundleTs() const { return message_bundle_ts_; } + + inline uint64_t GetLastMessageId() const { return last_message_id_; } + + inline uint32_t GetMessageListSize() const { return message_list_size_; } + + inline StreamingMessageBundleType GetBundleType() const { return bundle_type_; } + + inline bool IsBarrier() { return StreamingMessageBundleType::Barrier == bundle_type_; } + inline bool IsBundle() { return StreamingMessageBundleType::Bundle == bundle_type_; } + inline bool IsEmptyMsg() { return StreamingMessageBundleType::Empty == bundle_type_; } + + virtual void ToBytes(uint8_t *data); + static StreamingMessageBundleMetaPtr FromBytes(const uint8_t *data, + bool verifer_check = true); + inline virtual uint32_t ClassBytesSize() { return kMessageBundleMetaHeaderSize; } + + inline static bool CheckBundleMagicNum(const uint8_t *bytes) { + const uint32_t *magic_num = reinterpret_cast(bytes); + return *magic_num == StreamingMessageBundleMagicNum; + } + + std::string ToString() { + return std::to_string(last_message_id_) + "," + std::to_string(message_list_size_) + + "," + std::to_string(message_bundle_ts_) + "," + + std::to_string(static_cast(bundle_type_)); + } + + friend std::ostream &operator<<(std::ostream &os, + const StreamingMessageBundleMeta &meta); +}; + +/// StreamingMessageBundle inherits from 
metadata class (StreamingMessageBundleMeta) +/// with the following protocol: MagicNum = 0xcafebaba Timestamp 64bits timestamp +/// (milliseconds from 1970) LastMessageId( the last id of bundle) (0,INF] +/// MessageListSize(bundle len of message) +/// BundleType(a. bundle = 3 , b. barrier =2, c. empty = 1) +/// RawBundleSize(binary length of data) +/// RawData ( binary data) +/// +/// +--------------------+ +/// | MagicNum=U32 | +/// +--------------------+ +/// | BundleTs=U64 | +/// +--------------------+ +/// | LastMessageId=U64 | +/// +--------------------+ +/// | MessageListSize=U32| +/// +--------------------+ +/// | BundleType=U32 | +/// +--------------------+ +/// | RawBundleSize=U32 | +/// +--------------------+ +/// | RawData=var(N*Msg) | +/// +--------------------+ +/// It should be noted that StreamingMessageBundle and StreamingMessageBundleMeta share +/// almost same protocol but the last two fields (RawBundleSize and RawData). +class StreamingMessageBundle : public StreamingMessageBundleMeta { + private: + uint32_t raw_bundle_size_; + + // Lazy serlization/deserlization. + std::list message_list_; + + public: + explicit StreamingMessageBundle(std::list &&message_list, + uint64_t bundle_ts, uint64_t offset, + StreamingMessageBundleType bundle_type, + uint32_t raw_data_size = 0); + + // Duplicated copy if left reference in constructor. + explicit StreamingMessageBundle(std::list &message_list, + uint64_t bundle_ts, uint64_t offset, + StreamingMessageBundleType bundle_type, + uint32_t raw_data_size = 0); + + // New a empty bundle by passing last message id and timestamp. 
+ explicit StreamingMessageBundle(uint64_t, uint64_t); + + explicit StreamingMessageBundle(StreamingMessageBundle &bundle); + + virtual ~StreamingMessageBundle() = default; + + inline uint32_t GetRawBundleSize() const { return raw_bundle_size_; } + + bool operator==(StreamingMessageBundle &bundle) const; + + bool operator==(StreamingMessageBundle *bundle_ptr) const; + + void GetMessageList(std::list &message_list); + + const std::list &GetMessageList() const { return message_list_; } + + virtual void ToBytes(uint8_t *data); + static StreamingMessageBundlePtr FromBytes(const uint8_t *data, + bool verifer_check = true); + inline virtual uint32_t ClassBytesSize() { + return kMessageBundleHeaderSize + raw_bundle_size_; + }; + + static void GetMessageListFromRawData(const uint8_t *bytes, uint32_t bytes_size, + uint32_t message_list_size, + std::list &message_list); + + static void ConvertMessageListToRawData( + const std::list &message_list, uint32_t raw_data_size, + uint8_t *raw_data); +}; + +/// Databundle is super-bundle that contains channel information (upstream +/// channel id & bundle meta data) and raw buffer pointer. 
+struct DataBundle { + uint8_t *data = nullptr; + uint32_t data_size; + ObjectID from; + uint32_t last_barrier_id; + StreamingMessageBundleMetaPtr meta; + bool is_reallocated = false; + + ~DataBundle() { + if (is_reallocated) { + delete[] data; + } + } + + void Realloc(uint32_t size) { + data = new uint8_t[size]; + is_reallocated = true; + } + + friend std::ostream &operator<<(std::ostream &os, const DataBundle &bundle); +}; + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/message/priority_queue.h b/streaming/src/message/priority_queue.h new file mode 100644 index 00000000..6d53bd24 --- /dev/null +++ b/streaming/src/message/priority_queue.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include +#include + +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +template + +class PriorityQueue { + private: + std::vector merge_vec_; + C comparator_; + + public: + PriorityQueue(C &comparator) : comparator_(comparator){}; + + inline void push(T &&item) { + merge_vec_.push_back(std::forward(item)); + std::push_heap(merge_vec_.begin(), merge_vec_.end(), comparator_); + } + + inline void push(const T &item) { + merge_vec_.push_back(item); + std::push_heap(merge_vec_.begin(), merge_vec_.end(), comparator_); + } + + inline void pop() { + STREAMING_CHECK(!isEmpty()); + std::pop_heap(merge_vec_.begin(), merge_vec_.end(), comparator_); + merge_vec_.pop_back(); + } + + inline void makeHeap() { + std::make_heap(merge_vec_.begin(), merge_vec_.end(), comparator_); + } + + inline T &top() { return merge_vec_.front(); } + + inline uint32_t size() { return merge_vec_.size(); } + + inline bool isEmpty() { return merge_vec_.empty(); } + + std::vector &getRawVector() { return merge_vec_; } +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/metrics/stats_reporter.cc b/streaming/src/metrics/stats_reporter.cc new file mode 100644 index 00000000..286b8a38 --- /dev/null +++ b/streaming/src/metrics/stats_reporter.cc 
@@ -0,0 +1,117 @@ +#include "metrics/stats_reporter.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +std::shared_ptr StatsReporter::GetMetricByName( + const std::string &metric_name) { + std::unique_lock lock(metric_mutex_); + auto metric = metric_map_.find(metric_name); + if (metric != metric_map_.end()) { + return metric->second; + } + return nullptr; +} + +void StatsReporter::MetricRegister(const std::string &metric_name, + std::shared_ptr metric) { + std::unique_lock lock(metric_mutex_); + metric_map_[metric_name] = metric; +} + +void StatsReporter::UnregisterAllMetrics() { + std::unique_lock lock(metric_mutex_); + metric_map_.clear(); +} + +bool StatsReporter::Start(const StreamingMetricsConfig &conf) { + global_tags_ = conf.GetMetricsGlobalTags(); + service_name_ = conf.GetMetricsServiceName(); + STREAMING_LOG(INFO) << "Start stats reporter, service name " << service_name_ + << ", global tags size : " << global_tags_.size() + << ", stats disabled : " + << stats::StatsConfig::instance().IsStatsDisabled(); + for (auto &tag : global_tags_) { + global_tag_key_list_.push_back(stats::TagKeyType::Register(tag.first)); + } + return true; +} + +bool StatsReporter::Start(const std::string &json_string) { return true; } + +StatsReporter::~StatsReporter() { + STREAMING_LOG(WARNING) << "stats client shutdown"; + Shutdown(); +}; + +void StatsReporter::Shutdown() { UnregisterAllMetrics(); } + +void StatsReporter::UpdateCounter(const std::string &domain, + const std::string &group_name, + const std::string &short_name, double value) { + const std::string merged_metric_name = + METRIC_GROUP_JOIN(domain, group_name, short_name); +} + +void StatsReporter::UpdateCounter( + const std::string &metric_name, + const std::unordered_map &tags, double value) { + STREAMING_LOG(DEBUG) << "Report counter metric " << metric_name << " , value " << value; +} + +void StatsReporter::UpdateGauge(const std::string &domain, const std::string &group_name, + 
const std::string &short_name, double value, + bool is_reset) { + const std::string merged_metric_name = + service_name_ + "." + METRIC_GROUP_JOIN(domain, group_name, short_name); + STREAMING_LOG(DEBUG) << "Report gauge metric " << merged_metric_name << " , value " + << value; + auto metric = GetMetricByName(merged_metric_name); + if (nullptr == metric) { + metric = std::shared_ptr( + new ray::stats::Gauge(merged_metric_name, "", "", global_tag_key_list_)); + MetricRegister(merged_metric_name, metric); + } + metric->Record(value, global_tags_); +} + +void StatsReporter::UpdateGauge(const std::string &metric_name, + const std::unordered_map &tags, + double value, bool is_reset) { + const std::string merged_metric_name = service_name_ + "." + metric_name; + STREAMING_LOG(DEBUG) << "Report gauge metric " << merged_metric_name << " , value " + << value; + // Get metric from registered map, create a new one item if no such metric can be found + // in map. + auto metric = GetMetricByName(metric_name); + if (nullptr == metric) { + // Register tag key for all tags. 
+ std::vector tag_key_list(global_tag_key_list_.begin(), + global_tag_key_list_.end()); + for (auto &tag : tags) { + tag_key_list.push_back(stats::TagKeyType::Register(tag.first)); + } + metric = std::shared_ptr( + new ray::stats::Gauge(merged_metric_name, "", "", tag_key_list)); + MetricRegister(merged_metric_name, metric); + } + auto merged_tags = MergeGlobalTags(tags); + metric->Record(value, merged_tags); +} + +void StatsReporter::UpdateHistogram(const std::string &domain, + const std::string &group_name, + const std::string &short_name, double value, + double min_value, double max_value) {} + +void StatsReporter::UpdateHistogram( + const std::string &metric_name, + const std::unordered_map &tags, double value, + double min_value, double max_value) {} + +void StatsReporter::UpdateQPS(const std::string &metric_name, + const std::unordered_map &tags, + double value) {} +} // namespace streaming +} // namespace ray diff --git a/streaming/src/metrics/stats_reporter.h b/streaming/src/metrics/stats_reporter.h new file mode 100644 index 00000000..2577a1fb --- /dev/null +++ b/streaming/src/metrics/stats_reporter.h @@ -0,0 +1,88 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "ray/stats/metric.h" +#include "streaming_perf_metric.h" + +namespace ray { +namespace streaming { + +class StatsReporter : public StreamingReporterInterface { + public: + virtual ~StatsReporter(); + + bool Start(const StreamingMetricsConfig &conf) override; + + bool Start(const std::string &json_string); + + void Shutdown() override; + + void UpdateCounter(const std::string &domain, const std::string &group_name, + const std::string &short_name, double value) override; + + void UpdateGauge(const std::string &domain, const std::string &group_name, + const std::string &short_name, double value, + bool is_reset = true) override; + + void UpdateHistogram(const std::string &domain, const std::string &group_name, + const std::string &short_name, double value, double min_value, + double max_value) override; + + void UpdateCounter(const std::string &metric_name, + const std::unordered_map &tags, + double value) override; + + void UpdateGauge(const std::string &metric_name, + const std::unordered_map &tags, double value, + bool is_rest = true) override; + + void UpdateHistogram(const std::string &metric_name, + const std::unordered_map &tags, + double value, double min_value, double max_value) override; + + void UpdateQPS(const std::string &metric_name, + const std::unordered_map &tags, + double value) override; + + protected: + std::shared_ptr GetMetricByName(const std::string &metric_name); + void MetricRegister(const std::string &metric_name, + std::shared_ptr metric); + void UnregisterAllMetrics(); + + private: + inline std::unordered_map MergeGlobalTags( + const std::unordered_map &tags) { + std::unordered_map merged_tags; + merged_tags.insert(global_tags_.begin(), global_tags_.end()); + for (auto &item : tags) { + merged_tags.emplace(item.first, item.second); + } + return merged_tags; + } + + private: + std::mutex metric_mutex_; + std::unordered_map> metric_map_; + std::unordered_map global_tags_; + std::vector 
global_tag_key_list_; + std::string service_name_; +}; + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/metrics/streaming_perf_metric.cc b/streaming/src/metrics/streaming_perf_metric.cc new file mode 100644 index 00000000..c8a48191 --- /dev/null +++ b/streaming/src/metrics/streaming_perf_metric.cc @@ -0,0 +1,97 @@ +#include + +#include "metrics/stats_reporter.h" +#include "metrics/streaming_perf_metric.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +bool StreamingReporter::Start(const StreamingMetricsConfig &conf) { + if (impl_) { + STREAMING_LOG(WARNING) << "Streaming perf is active"; + } else { + impl_.reset(new StatsReporter()); + return impl_->Start(conf); + } + return false; +} + +void StreamingReporter::Shutdown() { + if (impl_) { + impl_->Shutdown(); + impl_.reset(); + } else { + STREAMING_LOG(WARNING) << "No active perf instance will be shutdown"; + } +} +void StreamingReporter::UpdateCounter(const std::string &domain, + const std::string &group_name, + const std::string &short_name, double value) { + if (impl_) { + impl_->UpdateCounter(domain, group_name, short_name, value); + } else { + STREAMING_LOG(WARNING) << "No active perf instance"; + } +} + +void StreamingReporter::UpdateGauge(const std::string &domain, + const std::string &group_name, + const std::string &short_name, double value, + bool is_reset) { + if (impl_) { + impl_->UpdateGauge(domain, group_name, short_name, value, is_reset); + } else { + STREAMING_LOG(WARNING) << "No active perf instance"; + } +} + +void StreamingReporter::UpdateHistogram(const std::string &domain, + const std::string &group_name, + const std::string &short_name, double value, + double min_value, double max_value) { + if (impl_) { + impl_->UpdateHistogram(domain, group_name, short_name, value, min_value, max_value); + } else { + STREAMING_LOG(WARNING) << "No active perf instance"; + } +} +void StreamingReporter::UpdateQPS( + const std::string &metric_name, + 
const std::unordered_map &tags, double value) { + if (impl_) { + impl_->UpdateQPS(metric_name, tags, value); + } else { + STREAMING_LOG(WARNING) << "No active perf instance"; + } +} + +StreamingReporter::~StreamingReporter() { + if (impl_) { + STREAMING_LOG(INFO) << "Destory streamimg perf => " << impl_.get(); + Shutdown(); + } +} + +void StreamingReporter::UpdateCounter( + const std::string &metric_name, + const std::unordered_map &tags, double value) { + if (impl_) { + impl_->UpdateCounter(metric_name, tags, value); + } +} +void StreamingReporter::UpdateGauge( + const std::string &metric_name, + const std::unordered_map &tags, double value, + bool is_rest) { + if (impl_) { + impl_->UpdateGauge(metric_name, tags, value, is_rest); + } +} +void StreamingReporter::UpdateHistogram( + const std::string &metric_name, + const std::unordered_map &tags, double value, + double min_value, double max_value) {} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/metrics/streaming_perf_metric.h b/streaming/src/metrics/streaming_perf_metric.h new file mode 100644 index 00000000..dbb8c850 --- /dev/null +++ b/streaming/src/metrics/streaming_perf_metric.h @@ -0,0 +1,97 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "config/streaming_config.h" + +namespace ray { +namespace streaming { +#define METRIC_GROUP_JOIN(a, b, c) (a + "." + b + "." 
+ c) + +class StreamingReporterInterface { + public: + virtual ~StreamingReporterInterface() = default; + virtual bool Start(const StreamingMetricsConfig &conf) = 0; + + virtual void Shutdown() = 0; + + virtual void UpdateCounter(const std::string &domain, const std::string &group_name, + const std::string &short_name, double value) = 0; + + virtual void UpdateGauge(const std::string &domain, const std::string &group_name, + const std::string &short_name, double value, + bool is_reset) = 0; + + virtual void UpdateHistogram(const std::string &domain, const std::string &group_name, + const std::string &short_name, double value, + double min_value, double max_value) = 0; + + virtual void UpdateCounter(const std::string &metric_name, + const std::unordered_map &tags, + double value) = 0; + + virtual void UpdateGauge(const std::string &metric_name, + const std::unordered_map &tags, + double value, bool is_rest) = 0; + + virtual void UpdateHistogram(const std::string &metric_name, + const std::unordered_map &tags, + double value, double min_value, double max_value) = 0; + + virtual void UpdateQPS(const std::string &metric_name, + const std::unordered_map &tags, + double value) = 0; +}; + +/// Streaming perf is a reporter instance based multiple backend. +/// Other modules can report gauge/histogram/counter/qps measurement to meteric server +/// side. 
+class StreamingReporter : public StreamingReporterInterface { + public: + StreamingReporter(){}; + virtual ~StreamingReporter(); + bool Start(const StreamingMetricsConfig &conf) override; + void Shutdown() override; + void UpdateCounter(const std::string &domain, const std::string &group_name, + const std::string &short_name, double value) override; + void UpdateGauge(const std::string &domain, const std::string &group_name, + const std::string &short_name, double value, + bool is_reset = true) override; + + void UpdateHistogram(const std::string &domain, const std::string &group_name, + const std::string &short_name, double value, double min_value, + double max_value) override; + + void UpdateCounter(const std::string &metric_name, + const std::unordered_map &tags, + double value) override; + + void UpdateGauge(const std::string &metric_name, + const std::unordered_map &tags, double value, + bool is_rest = true) override; + + void UpdateHistogram(const std::string &metric_name, + const std::unordered_map &tags, + double value, double min_value, double max_value) override; + + void UpdateQPS(const std::string &metric_name, + const std::unordered_map &tags, + double value) override; + + private: + std::unique_ptr impl_; +}; +} // namespace streaming + +} // namespace ray diff --git a/streaming/src/protobuf/remote_call.proto b/streaming/src/protobuf/remote_call.proto new file mode 100644 index 00000000..34c2dac7 --- /dev/null +++ b/streaming/src/protobuf/remote_call.proto @@ -0,0 +1,105 @@ +syntax = "proto3"; + +package ray.streaming.proto; + +import "protobuf/streaming.proto"; + +import "google/protobuf/any.proto"; + +option java_package = "io.ray.streaming.runtime.generated"; + +// Execution vertex info, including it's upstream and downstream +message ExecutionVertexContext { + // An edge between 2 execution vertices + message ExecutionEdge { + // upstream execution vertex id + int32 source_execution_vertex_id = 1; + // downstream execution vertex id + int32 
target_execution_vertex_id = 2; + // serialized partition between source/target vertex + bytes partition = 3; + } + + message ExecutionVertex { + // unique id of execution vertex + int32 execution_vertex_id = 1; + // unique id of execution job vertex + int32 execution_job_vertex_id = 2; + // name of execution job vertex, e.g. 1-SourceOperator + string execution_job_vertex_name = 3; + // index of execution vertex + int32 execution_vertex_index = 4; + int32 parallelism = 5; + // serialized operator + bytes operator = 6; + bool chained = 7; + bytes worker_actor = 8; + string container_id = 9; + uint64 build_time = 10; + Language language = 11; + map config = 12; + map resource = 13; + } + + // vertices + ExecutionVertex current_execution_vertex = 1; + repeated ExecutionVertex upstream_execution_vertices = 2; + repeated ExecutionVertex downstream_execution_vertices = 3; + + // edges + repeated ExecutionEdge input_execution_edges = 4; + repeated ExecutionEdge output_execution_edges = 5; +} + +// Streaming python worker context +message PythonJobWorkerContext { + // serialized master actor handle + bytes master_actor = 1; + // vertex including it's upstream and downstream + ExecutionVertexContext execution_vertex_context = 2; +} + +message BoolResult { + bool boolRes = 1; +} + +message Barrier { + int64 id = 1; +} + +message CheckpointId { + int64 checkpoint_id = 1; +} + +message BaseWorkerCmd { + bytes actor_id = 1; // actor id + int64 timestamp = 2; + google.protobuf.Any detail = 3; +} + +message WorkerCommitReport { + int64 commit_checkpoint_id = 1; +} + +message WorkerRollbackRequest { + string exception_msg = 1; + string worker_hostname = 2; + string worker_pid = 3; +} + +message CallResult { + bool success = 1; + int32 result_code = 2; + string result_msg = 3; + QueueRecoverInfo result_obj = 4; +} + +message QueueRecoverInfo { + enum QueueCreationStatus { + FreshStarted = 0; + PullOk = 1; + Timeout = 2; + DataLost = 3; + } + map creation_status = 3; +} \ No newline 
at end of file diff --git a/streaming/src/protobuf/streaming.proto b/streaming/src/protobuf/streaming.proto new file mode 100644 index 00000000..697928e2 --- /dev/null +++ b/streaming/src/protobuf/streaming.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; + +package ray.streaming.proto; + +option java_package = "io.ray.streaming.runtime.generated"; + +enum Language { + JAVA = 0; + PYTHON = 1; +} + +enum NodeType { + UNKNOWN = 0; + // Sources are where your program reads its input from + SOURCE = 1; + // Transform one or more DataStreams into a new DataStream. + TRANSFORM = 2; + // Sinks consume DataStreams and forward them to files, sockets, external + // systems, or print them. + SINK = 3; +} + +enum ReliabilityLevel { + NONE = 0; + AT_LEAST_ONCE = 1; + EXACTLY_ONCE = 2; +} + +enum FlowControlType { + UNKNOWN_FLOW_CONTROL_TYPE = 0; + UnconsumedSeqFlowControl = 1; + NoFlowControl = 2; +} + +// all string in this message is ASCII string +message StreamingConfig { + string job_name = 1; + string worker_name = 3; + string op_name = 4; + NodeType role = 5; + uint32 ring_buffer_capacity = 6; + uint32 empty_message_interval = 7; + FlowControlType flow_control_type = 8; + uint32 writer_consumed_step = 9; + uint32 reader_consumed_step = 10; + uint32 event_driven_flow_control_interval = 11; +} diff --git a/streaming/src/protobuf/streaming_queue.proto b/streaming/src/protobuf/streaming_queue.proto new file mode 100644 index 00000000..0fb26000 --- /dev/null +++ b/streaming/src/protobuf/streaming_queue.proto @@ -0,0 +1,98 @@ +syntax = "proto3"; + +package ray.streaming.queue.protobuf; + +enum StreamingQueueMessageType { + StreamingQueueDataMsgType = 0; + StreamingQueueCheckMsgType = 1; + StreamingQueueCheckRspMsgType = 2; + StreamingQueueNotificationMsgType = 3; + StreamingQueueTestInitMsgType = 4; + StreamingQueueTestCheckStatusRspMsgType = 5; + StreamingQueuePullRequestMsgType = 6; + StreamingQueuePullResponseMsgType = 7; + StreamingQueueResendDataMsgType = 8; +} + +enum 
StreamingQueueError { + OK = 0; + QUEUE_NOT_EXIST = 1; + DATA_LOST = 2; + NO_VALID_DATA = 3; +} + +message MessageCommon { + bytes src_actor_id = 1; + bytes dst_actor_id = 2; + bytes queue_id = 3; +} + +message StreamingQueueDataMsg { + MessageCommon common = 1; + uint64 seq_id = 2; + uint64 msg_id_start = 3; + uint64 msg_id_end = 4; + uint64 length = 5; + bool raw = 6; +} + +message StreamingQueueCheckMsg { + MessageCommon common = 1; +} + +message StreamingQueueCheckRspMsg { + MessageCommon common = 1; + StreamingQueueError err_code = 2; +} + +message StreamingQueueNotificationMsg { + MessageCommon common = 1; + uint64 seq_id = 2; +} + +// for test +enum StreamingQueueTestRole { + WRITER = 0; + READER = 1; +} + +message StreamingQueueTestInitMsg { + StreamingQueueTestRole role = 1; + bytes src_actor_id = 2; + bytes dst_actor_id = 3; + bytes actor_handle = 4; + repeated bytes queue_ids = 5; + repeated bytes rescale_queue_ids = 6; + string test_suite_name = 7; + string test_name = 8; + uint64 param = 9; +} + +message StreamingQueueTestCheckStatusRspMsg { + string test_name = 1; + bool status = 2; +} + +message StreamingQueuePullRequestMsg { + MessageCommon common = 1; + uint64 msg_id = 2; +} + +message StreamingQueuePullResponseMsg { + MessageCommon common = 1; + uint64 seq_id = 2; + uint64 msg_id = 3; + StreamingQueueError err_code = 4; + bool is_upstream_first_pull = 5; +} + +message StreamingQueueResendDataMsg { + MessageCommon common = 1; + uint64 first_seq_id = 2; + uint64 last_seq_id = 3; + uint64 seq_id = 4; + uint64 msg_id_start = 5; + uint64 msg_id_end = 6; + uint64 length = 7; + bool raw = 8; +} diff --git a/streaming/src/queue/message.cc b/streaming/src/queue/message.cc new file mode 100644 index 00000000..a74a7809 --- /dev/null +++ b/streaming/src/queue/message.cc @@ -0,0 +1,340 @@ +#include "queue/message.h" + +namespace ray { +namespace streaming { +const uint32_t Message::MagicNum = 0xBABA0510; + +std::unique_ptr Message::ToBytes() { + uint8_t *bytes 
= nullptr; + + std::string pboutput; + ToProtobuf(&pboutput); + int64_t fbs_length = pboutput.length(); + + queue::protobuf::StreamingQueueMessageType type = Type(); + size_t total_len = kItemHeaderSize + fbs_length; + if (buffer_ != nullptr) { + total_len += buffer_->Size(); + } + bytes = new uint8_t[total_len]; + STREAMING_CHECK(bytes != nullptr) << "allocate bytes fail."; + + uint8_t *p_cur = bytes; + memcpy(p_cur, &Message::MagicNum, sizeof(Message::MagicNum)); + + p_cur += sizeof(Message::MagicNum); + memcpy(p_cur, &type, sizeof(type)); + + p_cur += sizeof(type); + memcpy(p_cur, &fbs_length, sizeof(fbs_length)); + + p_cur += sizeof(fbs_length); + uint8_t *fbs_bytes = (uint8_t *)pboutput.data(); + memcpy(p_cur, fbs_bytes, fbs_length); + p_cur += fbs_length; + + if (buffer_ != nullptr) { + memcpy(p_cur, buffer_->Data(), buffer_->Size()); + } + + // COPY + std::unique_ptr buffer = + std::make_unique(bytes, total_len, true); + delete[] bytes; + return buffer; +} + +void Message::FillMessageCommon(queue::protobuf::MessageCommon *common) { + common->set_src_actor_id(actor_id_.Binary()); + common->set_dst_actor_id(peer_actor_id_.Binary()); + common->set_queue_id(queue_id_.Binary()); +} + +void DataMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueDataMsg msg; + FillMessageCommon(msg.mutable_common()); + msg.set_seq_id(seq_id_); + msg.set_msg_id_start(msg_id_start_); + msg.set_msg_id_end(msg_id_end_); + msg.set_length(buffer_->Size()); + msg.set_raw(raw_); + msg.SerializeToString(output); +} + +std::shared_ptr DataMessage::FromBytes(uint8_t *bytes) { + uint64_t *fbs_length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *fbs_length); + queue::protobuf::StreamingQueueDataMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = 
ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); + uint64_t msg_id_start = message.msg_id_start(); + uint64_t msg_id_end = message.msg_id_end(); + uint64_t seq_id = message.seq_id(); + uint64_t length = message.length(); + bool raw = message.raw(); + bytes += *fbs_length; + + /// Copy data and create a new buffer for streaming queue. + std::shared_ptr buffer = + std::make_shared(bytes, (size_t)length, true); + std::shared_ptr data_msg = + std::make_shared(src_actor_id, dst_actor_id, queue_id, seq_id, + msg_id_start, msg_id_end, buffer, raw); + + return data_msg; +} + +void NotificationMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueNotificationMsg msg; + FillMessageCommon(msg.mutable_common()); + msg.set_seq_id(msg_id_); + msg.SerializeToString(output); +} + +std::shared_ptr NotificationMessage::FromBytes(uint8_t *bytes) { + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueueNotificationMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); + uint64_t seq_id = message.seq_id(); + + std::shared_ptr notify_msg = + std::make_shared(src_actor_id, dst_actor_id, queue_id, seq_id); + + return notify_msg; +} + +void CheckMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueCheckMsg msg; + FillMessageCommon(msg.mutable_common()); + msg.SerializeToString(output); +} + +std::shared_ptr CheckMessage::FromBytes(uint8_t *bytes) { + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *length); + 
queue::protobuf::StreamingQueueCheckMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); + + std::shared_ptr check_msg = + std::make_shared(src_actor_id, dst_actor_id, queue_id); + + return check_msg; +} + +void CheckRspMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueCheckRspMsg msg; + FillMessageCommon(msg.mutable_common()); + msg.set_err_code(err_code_); + msg.SerializeToString(output); +} + +std::shared_ptr CheckRspMessage::FromBytes(uint8_t *bytes) { + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueueCheckRspMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); + queue::protobuf::StreamingQueueError err_code = message.err_code(); + + std::shared_ptr check_rsp_msg = + std::make_shared(src_actor_id, dst_actor_id, queue_id, err_code); + + return check_rsp_msg; +} + +void PullRequestMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueuePullRequestMsg msg; + FillMessageCommon(msg.mutable_common()); + msg.set_msg_id(msg_id_); + msg.SerializeToString(output); +} + +std::shared_ptr PullRequestMessage::FromBytes(uint8_t *bytes) { + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueuePullRequestMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + 
ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); + uint64_t msg_id = message.msg_id(); + STREAMING_LOG(DEBUG) << "src_actor_id:" << src_actor_id + << " dst_actor_id:" << dst_actor_id << " queue_id:" << queue_id + << " msg_id:" << msg_id; + + std::shared_ptr pull_msg = + std::make_shared(src_actor_id, dst_actor_id, queue_id, msg_id); + return pull_msg; +} + +void PullResponseMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueuePullResponseMsg msg; + FillMessageCommon(msg.mutable_common()); + msg.set_seq_id(seq_id_); + msg.set_msg_id(msg_id_); + msg.set_err_code(err_code_); + msg.set_is_upstream_first_pull(is_upstream_first_pull_); + msg.SerializeToString(output); +} + +std::shared_ptr PullResponseMessage::FromBytes(uint8_t *bytes) { + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueuePullResponseMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); + uint64_t seq_id = message.seq_id(); + uint64_t msg_id = message.msg_id(); + queue::protobuf::StreamingQueueError err_code = message.err_code(); + bool is_upstream_first_pull = message.is_upstream_first_pull(); + + STREAMING_LOG(INFO) << "src_actor_id:" << src_actor_id + << " dst_actor_id:" << dst_actor_id << " queue_id:" << queue_id + << " seq_id: " << seq_id << " msg_id: " << msg_id << " err_code:" + << queue::protobuf::StreamingQueueError_Name(err_code) + << " is_upstream_first_pull: " << is_upstream_first_pull; + + std::shared_ptr pull_rsp_msg = + std::make_shared(src_actor_id, dst_actor_id, queue_id, seq_id, + msg_id, err_code, 
is_upstream_first_pull); + + return pull_rsp_msg; +} + +void ResendDataMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueResendDataMsg msg; + FillMessageCommon(msg.mutable_common()); + msg.set_first_seq_id(first_seq_id_); + msg.set_last_seq_id(last_seq_id_); + msg.set_seq_id(seq_id_); + msg.set_msg_id_start(msg_id_start_); + msg.set_msg_id_end(msg_id_end_); + msg.set_length(buffer_->Size()); + msg.set_raw(raw_); + msg.SerializeToString(output); +} + +std::shared_ptr ResendDataMessage::FromBytes(uint8_t *bytes) { + uint64_t *fbs_length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *fbs_length); + queue::protobuf::StreamingQueueResendDataMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.common().src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.common().dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.common().queue_id()); + uint64_t first_seq_id = message.first_seq_id(); + uint64_t last_seq_id = message.last_seq_id(); + uint64_t seq_id = message.seq_id(); + uint64_t msg_id_start = message.msg_id_start(); + uint64_t msg_id_end = message.msg_id_end(); + uint64_t length = message.length(); + bool raw = message.raw(); + + STREAMING_LOG(DEBUG) << "src_actor_id:" << src_actor_id + << " dst_actor_id:" << dst_actor_id + << " first_seq_id:" << first_seq_id << " seq_id:" << seq_id + << " msg_id_start: " << msg_id_start + << " msg_id_end: " << msg_id_end << " last_seq_id:" << last_seq_id + << " queue_id:" << queue_id << " length:" << length; + + bytes += *fbs_length; + /// COPY + std::shared_ptr buffer = + std::make_shared(bytes, (size_t)length, true); + std::shared_ptr pull_data_msg = std::make_shared( + src_actor_id, dst_actor_id, queue_id, first_seq_id, seq_id, msg_id_start, + msg_id_end, last_seq_id, buffer, raw); + + return pull_data_msg; +} + +void 
TestInitMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueTestInitMsg msg; + msg.set_role(role_); + msg.set_src_actor_id(actor_id_.Binary()); + msg.set_dst_actor_id(peer_actor_id_.Binary()); + msg.set_actor_handle(actor_handle_serialized_); + for (auto &queue_id : queue_ids_) { + msg.add_queue_ids(queue_id.Binary()); + } + for (auto &queue_id : rescale_queue_ids_) { + msg.add_rescale_queue_ids(queue_id.Binary()); + } + msg.set_test_suite_name(test_suite_name_); + msg.set_test_name(test_name_); + msg.set_param(param_); + msg.SerializeToString(output); +} + +std::shared_ptr TestInitMessage::FromBytes(uint8_t *bytes) { + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueueTestInitMsg message; + message.ParseFromString(inputpb); + queue::protobuf::StreamingQueueTestRole role = message.role(); + ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id()); + std::string actor_handle_serialized = message.actor_handle(); + std::vector queue_ids; + for (int i = 0; i < message.queue_ids_size(); i++) { + queue_ids.push_back(ObjectID::FromBinary(message.queue_ids(i))); + } + std::vector rescale_queue_ids; + for (int i = 0; i < message.rescale_queue_ids_size(); i++) { + rescale_queue_ids.push_back(ObjectID::FromBinary(message.rescale_queue_ids(i))); + } + std::string test_suite_name = message.test_suite_name(); + std::string test_name = message.test_name(); + uint64_t param = message.param(); + + std::shared_ptr test_init_msg = std::make_shared( + role, src_actor_id, dst_actor_id, actor_handle_serialized, queue_ids, + rescale_queue_ids, test_suite_name, test_name, param); + + return test_init_msg; +} + +void TestCheckStatusRspMsg::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueTestCheckStatusRspMsg msg; + 
msg.set_test_name(test_name_); + msg.set_status(status_); + msg.SerializeToString(output); +} + +std::shared_ptr TestCheckStatusRspMsg::FromBytes(uint8_t *bytes) { + uint64_t *length = (uint64_t *)(bytes + kItemMetaHeaderSize); + bytes += kItemHeaderSize; + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueueTestCheckStatusRspMsg message; + message.ParseFromString(inputpb); + std::string test_name = message.test_name(); + bool status = message.status(); + + std::shared_ptr test_check_msg = + std::make_shared(test_name, status); + + return test_check_msg; +} +} // namespace streaming +} // namespace ray diff --git a/streaming/src/queue/message.h b/streaming/src/queue/message.h new file mode 100644 index 00000000..9438a471 --- /dev/null +++ b/streaming/src/queue/message.h @@ -0,0 +1,335 @@ +#pragma once + +#include "protobuf/streaming_queue.pb.h" +#include "ray/common/buffer.h" +#include "ray/common/id.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +/// Base class of all message classes. +/// All payloads transferred through direct actor call are packed into a unified package, +/// consisting of protobuf-formatted metadata and data, including data and control +/// messages. These message classes wrap the package defined in +/// protobuf/streaming_queue.proto respectively. +class Message { + public: + /// Construct a Message instance. + /// \param[in] actor_id ActorID of message sender. + /// \param[in] peer_actor_id ActorID of message receiver. + /// \param[in] queue_id queue id to identify which queue the message is sent to. + /// \param[in] buffer an optional param, a chunk of data to send. 
+ Message(const ActorID &actor_id, const ActorID &peer_actor_id, const ObjectID &queue_id, + std::shared_ptr buffer = nullptr) + : actor_id_(actor_id), + peer_actor_id_(peer_actor_id), + queue_id_(queue_id), + buffer_(buffer) {} + Message() {} + virtual ~Message() {} + inline ActorID ActorId() { return actor_id_; } + inline ActorID PeerActorId() { return peer_actor_id_; } + inline ObjectID QueueId() { return queue_id_; } + inline std::shared_ptr Buffer() { return buffer_; } + + /// Serialize all meta data and data to a LocalMemoryBuffer, which can be sent through + /// direct actor call. \return serialized buffer . + std::unique_ptr ToBytes(); + + /// Get message type. + /// \return message type. + virtual queue::protobuf::StreamingQueueMessageType Type() = 0; + + /// All subclasses should implement `ToProtobuf` to serialize its own protobuf data. + virtual void ToProtobuf(std::string *output) = 0; + + void FillMessageCommon(queue::protobuf::MessageCommon *common); + + protected: + ActorID actor_id_; + ActorID peer_actor_id_; + ObjectID queue_id_; + std::shared_ptr buffer_; + + public: + /// A magic number to identify a valid message. + static const uint32_t MagicNum; +}; + +/// MagicNum + MessageType +constexpr uint32_t kItemMetaHeaderSize = + sizeof(Message::MagicNum) + sizeof(queue::protobuf::StreamingQueueMessageType); +/// kItemMetaHeaderSize + fbs length +constexpr uint32_t kItemHeaderSize = kItemMetaHeaderSize + sizeof(uint64_t); + +/// Wrap StreamingQueueDataMsg in streaming_queue.proto. +/// DataMessage encapsulates the memory buffer of QueueItem, a one-to-one relationship +/// exists between DataMessage and QueueItem. 
+class DataMessage : public Message { + public: + DataMessage(const ActorID &actor_id, const ActorID &peer_actor_id, ObjectID queue_id, + uint64_t seq_id, uint64_t msg_id_start, uint64_t msg_id_end, + std::shared_ptr buffer, bool raw) + : Message(actor_id, peer_actor_id, queue_id, buffer), + seq_id_(seq_id), + msg_id_start_(msg_id_start), + msg_id_end_(msg_id_end), + raw_(raw) {} + virtual ~DataMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + inline uint64_t SeqId() { return seq_id_; } + inline uint64_t MsgIdStart() { return msg_id_start_; } + inline uint64_t MsgIdEnd() { return msg_id_end_; } + inline bool IsRaw() { return raw_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + + private: + uint64_t seq_id_; + uint64_t msg_id_start_; + uint64_t msg_id_end_; + bool raw_; + + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType; +}; + +/// Wrap StreamingQueueNotificationMsg in streaming_queue.proto. +/// NotificationMessage, downstream queues sends to upstream queues, for the data reader +/// to inform the data writer of the consumed offset. +class NotificationMessage : public Message { + public: + NotificationMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + const ObjectID &queue_id, uint64_t msg_id) + : Message(actor_id, peer_actor_id, queue_id), msg_id_(msg_id) {} + + virtual ~NotificationMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + + inline uint64_t MsgId() { return msg_id_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + + private: + uint64_t msg_id_; + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType; +}; + +/// Wrap StreamingQueueCheckMsg in streaming_queue.proto. 
+/// CheckMessage, upstream queues sends to downstream queues, fot the data writer to check +/// whether the corresponded downstream queue is read or not. +class CheckMessage : public Message { + public: + CheckMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + const ObjectID &queue_id) + : Message(actor_id, peer_actor_id, queue_id) {} + virtual ~CheckMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + + private: + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType; +}; + +/// Wrap StreamingQueueCheckRspMsg in streaming_queue.proto. +/// CheckRspMessage, downstream queues sends to upstream queues, the response message to +/// CheckMessage to indicate whether downstream queue is ready or not. +class CheckRspMessage : public Message { + public: + CheckRspMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + const ObjectID &queue_id, queue::protobuf::StreamingQueueError err_code) + : Message(actor_id, peer_actor_id, queue_id), err_code_(err_code) {} + virtual ~CheckRspMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + inline queue::protobuf::StreamingQueueError Error() { return err_code_; } + + private: + queue::protobuf::StreamingQueueError err_code_; + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType; +}; + +class PullRequestMessage : public Message { + public: + PullRequestMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + const ObjectID &queue_id, uint64_t msg_id) + : Message(actor_id, peer_actor_id, queue_id), msg_id_(msg_id) {} + virtual ~PullRequestMessage() {} + + static 
std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + inline uint64_t MsgId() { return msg_id_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + + private: + uint64_t msg_id_; + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueuePullRequestMsgType; +}; + +class PullResponseMessage : public Message { + public: + PullResponseMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + const ObjectID &queue_id, uint64_t seq_id, uint64_t msg_id, + queue::protobuf::StreamingQueueError err_code, + bool is_upstream_first_pull) + : Message(actor_id, peer_actor_id, queue_id), + seq_id_(seq_id), + msg_id_(msg_id), + is_upstream_first_pull_(is_upstream_first_pull), + err_code_(err_code) {} + virtual ~PullResponseMessage() = default; + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + inline uint64_t SeqId() { return seq_id_; } + inline uint64_t MsgId() { return msg_id_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + inline queue::protobuf::StreamingQueueError Error() { return err_code_; } + inline bool IsUpstreamFirstPull() { return is_upstream_first_pull_; } + + private: + uint64_t seq_id_; + uint64_t msg_id_; + bool is_upstream_first_pull_; + queue::protobuf::StreamingQueueError err_code_; + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueuePullResponseMsgType; +}; + +class ResendDataMessage : public Message { + public: + ResendDataMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + ObjectID queue_id, uint64_t first_seq_id, uint64_t seq_id, + uint64_t msg_id_start, uint64_t msg_id_end, uint64_t last_seq_id, + std::shared_ptr buffer, bool raw) + : Message(actor_id, peer_actor_id, queue_id, buffer), + first_seq_id_(first_seq_id), + last_seq_id_(last_seq_id), + 
seq_id_(seq_id), + msg_id_start_(msg_id_start), + msg_id_end_(msg_id_end), + raw_(raw) {} + virtual ~ResendDataMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + inline uint64_t MsgIdStart() { return msg_id_start_; } + inline uint64_t MsgIdEnd() { return msg_id_end_; } + inline uint64_t SeqId() { return seq_id_; } + inline uint64_t FirstSeqId() { return first_seq_id_; } + inline uint64_t LastSeqId() { return last_seq_id_; } + inline bool IsRaw() { return raw_; } + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + + private: + uint64_t first_seq_id_; + uint64_t last_seq_id_; + uint64_t seq_id_; + uint64_t msg_id_start_; + uint64_t msg_id_end_; + bool raw_; + + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueResendDataMsgType; +}; + +/// Wrap StreamingQueueTestInitMsg in streaming_queue.proto. +/// TestInitMessage, used for test, driver sends to test workers to init test suite. 
+class TestInitMessage : public Message { + public: + TestInitMessage(const queue::protobuf::StreamingQueueTestRole role, + const ActorID &actor_id, const ActorID &peer_actor_id, + const std::string actor_handle_serialized, + const std::vector &queue_ids, + const std::vector &rescale_queue_ids, + std::string test_suite_name, std::string test_name, uint64_t param) + : Message(actor_id, peer_actor_id, queue_ids[0]), + actor_handle_serialized_(actor_handle_serialized), + queue_ids_(queue_ids), + rescale_queue_ids_(rescale_queue_ids), + role_(role), + test_suite_name_(test_suite_name), + test_name_(test_name), + param_(param) {} + virtual ~TestInitMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + inline std::string ActorHandleSerialized() { return actor_handle_serialized_; } + inline queue::protobuf::StreamingQueueTestRole Role() { return role_; } + inline std::vector QueueIds() { return queue_ids_; } + inline std::vector RescaleQueueIds() { return rescale_queue_ids_; } + inline std::string TestSuiteName() { return test_suite_name_; } + inline std::string TestName() { return test_name_; } + inline uint64_t Param() { return param_; } + + std::string ToString() { + std::ostringstream os; + os << "actor_handle_serialized: " << actor_handle_serialized_; + os << " actor_id: " << ActorId(); + os << " peer_actor_id: " << PeerActorId(); + os << " queue_ids:["; + for (auto &qid : queue_ids_) { + os << qid << ","; + } + os << "], rescale_queue_ids:["; + for (auto &qid : rescale_queue_ids_) { + os << qid << ","; + } + os << "],"; + os << " role:" << queue::protobuf::StreamingQueueTestRole_Name(role_); + os << " suite_name: " << test_suite_name_; + os << " test_name: " << test_name_; + os << " param: " << param_; + return os.str(); + } + + private: + const queue::protobuf::StreamingQueueMessageType type_ = + 
queue::protobuf::StreamingQueueMessageType::StreamingQueueTestInitMsgType; + std::string actor_handle_serialized_; + std::vector queue_ids_; + std::vector rescale_queue_ids_; + queue::protobuf::StreamingQueueTestRole role_; + std::string test_suite_name_; + std::string test_name_; + uint64_t param_; +}; + +/// Wrap StreamingQueueTestCheckStatusRspMsg in streaming_queue.proto. +/// TestCheckStatusRspMsg, used for test, driver sends to test workers to check +/// whether test has completed or failed. +class TestCheckStatusRspMsg : public Message { + public: + TestCheckStatusRspMsg(const std::string test_name, bool status) + : test_name_(test_name), status_(status) {} + virtual ~TestCheckStatusRspMsg() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } + inline std::string TestName() { return test_name_; } + inline bool Status() { return status_; } + + private: + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueTestCheckStatusRspMsgType; + std::string test_name_; + bool status_; +}; + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/queue/queue.cc b/streaming/src/queue/queue.cc new file mode 100644 index 00000000..1885a460 --- /dev/null +++ b/streaming/src/queue/queue.cc @@ -0,0 +1,334 @@ +#include "queue/queue.h" + +#include +#include + +#include "queue/queue_handler.h" +#include "util/streaming_util.h" + +namespace ray { +namespace streaming { + +bool Queue::Push(QueueItem item) { + std::unique_lock lock(mutex_); + if (max_data_size_ < item.DataSize() + data_size_) return false; + + buffer_queue_.push_back(item); + data_size_ += item.DataSize(); + readable_cv_.notify_one(); + return true; +} + +QueueItem Queue::FrontProcessed() { + std::unique_lock lock(mutex_); + STREAMING_CHECK(buffer_queue_.size() != 0) << "WriterQueue Pop fail"; + + if 
(watershed_iter_ == buffer_queue_.begin()) { + return InvalidQueueItem(); + } + + QueueItem item = buffer_queue_.front(); + return item; +} + +QueueItem Queue::PopProcessed() { + std::unique_lock lock(mutex_); + STREAMING_CHECK(buffer_queue_.size() != 0) << "WriterQueue Pop fail"; + + if (watershed_iter_ == buffer_queue_.begin()) { + return InvalidQueueItem(); + } + + QueueItem item = buffer_queue_.front(); + buffer_queue_.pop_front(); + data_size_ -= item.DataSize(); + data_size_sent_ -= item.DataSize(); + return item; +} + +QueueItem Queue::PopPending() { + std::unique_lock lock(mutex_); + auto it = std::next(watershed_iter_); + QueueItem item = *it; + data_size_sent_ += it->DataSize(); + buffer_queue_.splice(watershed_iter_, buffer_queue_, it, std::next(it)); + return item; +} + +QueueItem Queue::PopPendingBlockTimeout(uint64_t timeout_us) { + std::unique_lock lock(mutex_); + std::chrono::system_clock::time_point point = + std::chrono::system_clock::now() + std::chrono::microseconds(timeout_us); + if (readable_cv_.wait_until(lock, point, [this] { + return std::next(watershed_iter_) != buffer_queue_.end(); + })) { + auto it = std::next(watershed_iter_); + QueueItem item = *it; + data_size_sent_ += it->DataSize(); + buffer_queue_.splice(watershed_iter_, buffer_queue_, it, std::next(it)); + return item; + + } else { + return InvalidQueueItem(); + } +} + +QueueItem Queue::BackPending() { + std::unique_lock lock(mutex_); + if (std::next(watershed_iter_) == buffer_queue_.end()) { + return InvalidQueueItem(); + } + return buffer_queue_.back(); +} + +size_t Queue::ProcessedCount() { + std::unique_lock lock(mutex_); + if (watershed_iter_ == buffer_queue_.begin()) return 0; + + auto begin = buffer_queue_.begin(); + auto end = std::prev(watershed_iter_); + + return end->SeqId() + 1 - begin->SeqId(); +} + +size_t Queue::PendingCount() { + std::unique_lock lock(mutex_); + if (std::next(watershed_iter_) == buffer_queue_.end()) return 0; + + auto begin = 
std::next(watershed_iter_); + auto end = std::prev(buffer_queue_.end()); + + return begin->SeqId() - end->SeqId() + 1; +} + +Status WriterQueue::Push(uint8_t *buffer, uint32_t buffer_size, uint64_t timestamp, + uint64_t msg_id_start, uint64_t msg_id_end, bool raw) { + if (IsPendingFull(buffer_size)) { + return Status::OutOfMemory("Queue Push OutOfMemory"); + } + + while (is_resending_) { + STREAMING_LOG(INFO) << "This queue is resending data, wait."; + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + QueueItem item(seq_id_, buffer, buffer_size, timestamp, msg_id_start, msg_id_end, raw); + Queue::Push(item); + STREAMING_LOG(DEBUG) << "WriterQueue::Push seq_id: " << seq_id_; + seq_id_++; + return Status::OK(); +} + +void WriterQueue::Send() { + while (!IsPendingEmpty()) { + QueueItem item = PopPending(); + DataMessage msg(actor_id_, peer_actor_id_, queue_id_, item.SeqId(), item.MsgIdStart(), + item.MsgIdEnd(), item.Buffer(), item.IsRaw()); + std::unique_ptr buffer = msg.ToBytes(); + STREAMING_CHECK(transport_ != nullptr); + transport_->Send(std::move(buffer)); + } +} + +Status WriterQueue::TryEvictItems() { + QueueItem item = FrontProcessed(); + STREAMING_LOG(DEBUG) << "TryEvictItems queue_id: " << queue_id_ << " first_item: (" + << item.MsgIdStart() << "," << item.MsgIdEnd() << ")" + << " min_consumed_msg_id_: " << min_consumed_msg_id_ + << " eviction_limit_: " << eviction_limit_ + << " max_data_size_: " << max_data_size_ + << " data_size_sent_: " << data_size_sent_ + << " data_size_: " << data_size_; + + if (min_consumed_msg_id_ == QUEUE_INVALID_SEQ_ID || + min_consumed_msg_id_ < item.MsgIdEnd()) { + return Status::OutOfMemory("The queue is full and some reader doesn't consume"); + } + + if (eviction_limit_ == QUEUE_INVALID_SEQ_ID || eviction_limit_ < item.MsgIdEnd()) { + return Status::OutOfMemory("The queue is full and eviction limit block evict"); + } + + uint64_t evict_target_msg_id = std::min(min_consumed_msg_id_, eviction_limit_); + + int 
count = 0; + while (item.MsgIdEnd() <= evict_target_msg_id) { + PopProcessed(); + STREAMING_LOG(INFO) << "TryEvictItems directly " << item.MsgIdEnd(); + item = FrontProcessed(); + count++; + } + STREAMING_LOG(DEBUG) << count << " items evicted, current item: (" << item.MsgIdStart() + << "," << item.MsgIdEnd() << ")"; + return Status::OK(); +} + +void WriterQueue::OnNotify(std::shared_ptr notify_msg) { + STREAMING_LOG(INFO) << "OnNotify target msg_id: " << notify_msg->MsgId(); + min_consumed_msg_id_ = notify_msg->MsgId(); +} + +void WriterQueue::ResendItem(QueueItem &item, uint64_t first_seq_id, + uint64_t last_seq_id) { + ResendDataMessage msg(actor_id_, peer_actor_id_, queue_id_, first_seq_id, item.SeqId(), + item.MsgIdStart(), item.MsgIdEnd(), last_seq_id, item.Buffer(), + item.IsRaw()); + STREAMING_CHECK(item.Buffer()->Data() != nullptr); + std::unique_ptr buffer = msg.ToBytes(); + + transport_->Send(std::move(buffer)); +} + +int WriterQueue::ResendItems(std::list::iterator start_iter, + uint64_t first_seq_id, uint64_t last_seq_id) { + std::unique_lock lock(mutex_); + int count = 0; + auto it = start_iter; + for (; it != watershed_iter_; it++) { + if (it->SeqId() > last_seq_id) { + break; + } + STREAMING_LOG(INFO) << "ResendItems send seq_id " << it->SeqId() << " to peer."; + ResendItem(*it, first_seq_id, last_seq_id); + count++; + } + + STREAMING_LOG(INFO) << "ResendItems total count: " << count; + is_resending_ = false; + return count; +} + +void WriterQueue::FindItem( + uint64_t target_msg_id, std::function greater_callback, + std::function less_callback, + std::function::iterator, uint64_t, uint64_t)> + equal_callback) { + auto last_one = std::prev(watershed_iter_); + bool last_item_too_small = + last_one != buffer_queue_.end() && last_one->MsgIdEnd() < target_msg_id; + + if (QUEUE_INITIAL_SEQ_ID == seq_id_ || last_item_too_small) { + greater_callback(); + return; + } + + auto begin = buffer_queue_.begin(); + uint64_t first_seq_id = (*begin).SeqId(); + 
uint64_t last_seq_id = first_seq_id + std::distance(begin, watershed_iter_) - 1; + STREAMING_LOG(INFO) << "FindItem last_seq_id: " << last_seq_id + << " first_seq_id: " << first_seq_id; + + auto target_item = std::find_if( + begin, watershed_iter_, + [&target_msg_id](QueueItem &item) { return item.InItem(target_msg_id); }); + + if (target_item != watershed_iter_) { + equal_callback(target_item, first_seq_id, last_seq_id); + } else { + less_callback(); + } +} + +void WriterQueue::OnPull( + std::shared_ptr pull_msg, boost::asio::io_service &service, + std::function)> callback) { + std::unique_lock lock(mutex_); + STREAMING_CHECK(peer_actor_id_ == pull_msg->ActorId()) + << peer_actor_id_ << " " << pull_msg->ActorId(); + + FindItem(pull_msg->MsgId(), + /// target_msg_id is too large. + [this, &pull_msg, &callback]() { + STREAMING_LOG(WARNING) + << "No valid data to pull, the writer has not push data yet. "; + PullResponseMessage msg(pull_msg->PeerActorId(), pull_msg->ActorId(), + pull_msg->QueueId(), QUEUE_INVALID_SEQ_ID, + QUEUE_INVALID_SEQ_ID, + queue::protobuf::StreamingQueueError::NO_VALID_DATA, + is_upstream_first_pull_); + std::unique_ptr buffer = msg.ToBytes(); + is_upstream_first_pull_ = false; + callback(std::move(buffer)); + }, + /// target_msg_id is too small. + [this, &pull_msg, &callback]() { + STREAMING_LOG(WARNING) << "Data lost."; + PullResponseMessage msg(pull_msg->PeerActorId(), pull_msg->ActorId(), + pull_msg->QueueId(), QUEUE_INVALID_SEQ_ID, + QUEUE_INVALID_SEQ_ID, + queue::protobuf::StreamingQueueError::DATA_LOST, + is_upstream_first_pull_); + std::unique_ptr buffer = msg.ToBytes(); + + callback(std::move(buffer)); + }, + /// target_msg_id found. 
+ [this, &pull_msg, &callback, &service]( + std::list::iterator target_item, uint64_t first_seq_id, + uint64_t last_seq_id) { + is_resending_ = true; + STREAMING_LOG(INFO) << "OnPull return"; + service.post(std::bind(&WriterQueue::ResendItems, this, target_item, + first_seq_id, last_seq_id)); + PullResponseMessage msg( + pull_msg->PeerActorId(), pull_msg->ActorId(), pull_msg->QueueId(), + target_item->SeqId(), pull_msg->MsgId(), + queue::protobuf::StreamingQueueError::OK, is_upstream_first_pull_); + std::unique_ptr buffer = msg.ToBytes(); + is_upstream_first_pull_ = false; + callback(std::move(buffer)); + }); +} + +void ReaderQueue::OnConsumed(uint64_t msg_id) { + STREAMING_LOG(INFO) << "OnConsumed: " << msg_id; + QueueItem item = FrontProcessed(); + while (item.MsgIdEnd() <= msg_id) { + PopProcessed(); + item = FrontProcessed(); + } + Notify(msg_id); +} + +void ReaderQueue::Notify(uint64_t msg_id) { + std::vector task_args; + CreateNotifyTask(msg_id, task_args); + // SubmitActorTask + + NotificationMessage msg(actor_id_, peer_actor_id_, queue_id_, msg_id); + std::unique_ptr buffer = msg.ToBytes(); + + transport_->Send(std::move(buffer)); +} + +void ReaderQueue::CreateNotifyTask(uint64_t seq_id, std::vector &task_args) {} + +void ReaderQueue::OnData(QueueItem &item) { + last_recv_seq_id_ = item.SeqId(); + last_recv_msg_id_ = item.MsgIdEnd(); + STREAMING_LOG(DEBUG) << "ReaderQueue::OnData queue_id: " << queue_id_ + << " seq_id: " << last_recv_seq_id_ << " msg_id: (" + << item.MsgIdStart() << "," << item.MsgIdEnd() << ")"; + + Push(item); +} + +void ReaderQueue::OnResendData(std::shared_ptr msg) { + STREAMING_LOG(INFO) << "OnResendData queue_id: " << queue_id_ << " recv seq_id " + << msg->SeqId() << "(" << msg->FirstSeqId() << "/" + << msg->LastSeqId() << ")"; + QueueItem item(msg->SeqId(), msg->Buffer(), 0, msg->MsgIdStart(), msg->MsgIdEnd(), + msg->IsRaw()); + STREAMING_CHECK(msg->Buffer()->Data() != nullptr); + + Push(item); + STREAMING_CHECK(msg->SeqId() >= 
msg->FirstSeqId() && msg->SeqId() <= msg->LastSeqId()) + << "(" << msg->FirstSeqId() << "/" << msg->SeqId() << "/" << msg->LastSeqId() + << ")"; + if (msg->SeqId() == msg->LastSeqId()) { + STREAMING_LOG(INFO) << "Resend DATA Done"; + } +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/queue/queue.h b/streaming/src/queue/queue.h new file mode 100644 index 00000000..758b9093 --- /dev/null +++ b/streaming/src/queue/queue.h @@ -0,0 +1,272 @@ +#pragma once + +#include +#include +#include + +#include "queue/queue_item.h" +#include "queue/transport.h" +#include "queue/utils.h" +#include "ray/common/id.h" +#include "ray/util/util.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +using ray::ObjectID; + +enum QueueType { UPSTREAM = 0, DOWNSTREAM }; + +/// A queue-like data structure, which does not delete its items after poped. +/// The lifecycle of each item is: +/// - Pending, an item is pushed into a queue, but has not been processed (sent out or +/// consumed), +/// - Processed, has been handled by the user, but should not be deleted. +/// - Evicted, useless to the user, should be poped and destroyed. +/// At present, this data structure is implemented with one std::list, +/// using a watershed iterator to divided. +class Queue { + public: + /// \param actor_id, the actor id of upstream worker + /// \param peer_actor_id, the actor id of downstream worker + /// \param queue_id the unique identification of a pair of queues (upstream and + /// downstream). + /// \param size max size of the queue in bytes. + /// \param transport + /// transport to send items to peer. 
+ Queue(const ActorID &actor_id, const ActorID &peer_actor_id, ObjectID queue_id, + uint64_t size, std::shared_ptr transport) + : actor_id_(actor_id), + peer_actor_id_(peer_actor_id), + queue_id_(queue_id), + max_data_size_(size), + data_size_(0), + data_size_sent_(0) { + buffer_queue_.push_back(InvalidQueueItem()); + watershed_iter_ = buffer_queue_.begin(); + } + + virtual ~Queue() {} + + /// Push an item into the queue. + /// \param[in] item the QueueItem object to be send to peer. + /// \return false if the queue is full. + bool Push(QueueItem item); + + /// Get the front of item which in processed state. + QueueItem FrontProcessed(); + + /// Pop the front of item which in processed state. + QueueItem PopProcessed(); + + /// Pop the front of item which in pending state, the item + /// will not be evicted at this moment, its state turn to + /// processed. + QueueItem PopPending(); + + /// PopPending with timeout in microseconds. + QueueItem PopPendingBlockTimeout(uint64_t timeout_us); + + /// Return the last item in pending state. + QueueItem BackPending(); + + inline bool IsPendingEmpty() { + std::unique_lock lock(mutex_); + return std::next(watershed_iter_) == buffer_queue_.end(); + }; + + inline bool IsPendingFull(uint64_t data_size = 0) { + std::unique_lock lock(mutex_); + return max_data_size_ < data_size + data_size_; + } + + /// Return the size in bytes of all items in queue. + inline uint64_t QueueSize() { return data_size_; } + + /// Return the size in bytes of all items in pending state. + inline uint64_t PendingDataSize() { return data_size_ - data_size_sent_; } + + /// Return the size in bytes of all items in processed state. + inline uint64_t ProcessedDataSize() { return data_size_sent_; } + + /// Return item count of the queue. + inline size_t Count() { return buffer_queue_.size(); } + + /// Return item count in pending state. + inline size_t PendingCount(); + + /// Return item count in processed state. 
+ inline size_t ProcessedCount(); + + inline ActorID GetActorID() { return actor_id_; } + inline ActorID GetPeerActorID() { return peer_actor_id_; } + inline ObjectID GetQueueID() { return queue_id_; } + + protected: + std::list buffer_queue_; + std::list::iterator watershed_iter_; + + ActorID actor_id_; + ActorID peer_actor_id_; + ObjectID queue_id_; + /// max data size in bytes + uint64_t max_data_size_; + uint64_t data_size_; + uint64_t data_size_sent_; + + std::mutex mutex_; + std::condition_variable readable_cv_; +}; + +/// Queue in upstream. +class WriterQueue : public Queue { + public: + /// \param queue_id, the unique ObjectID to identify a queue + /// \param actor_id, the actor id of upstream worker + /// \param peer_actor_id, the actor id of downstream worker + /// \param size, max data size in bytes + /// \param transport, transport + WriterQueue(const ObjectID &queue_id, const ActorID &actor_id, + const ActorID &peer_actor_id, uint64_t size, + std::shared_ptr transport) + : Queue(actor_id, peer_actor_id, queue_id, size, transport), + actor_id_(actor_id), + peer_actor_id_(peer_actor_id), + seq_id_(QUEUE_INITIAL_SEQ_ID), + eviction_limit_(QUEUE_INVALID_SEQ_ID), + min_consumed_msg_id_(QUEUE_INVALID_SEQ_ID), + peer_last_msg_id_(0), + peer_last_seq_id_(QUEUE_INVALID_SEQ_ID), + transport_(transport), + is_resending_(false), + is_upstream_first_pull_(true) {} + + /// Push a continuous buffer into queue, the buffer consists of some messages packed by + /// DataWriter. 
+ /// \param data, the buffer address + /// \param data_size, buffer size + /// \param timestamp, the timestamp when the buffer pushed in + /// \param msg_id_start, the message id of the first message in the buffer + /// \param msg_id_end, the message id of the last message in the buffer + /// \param raw, whether this buffer is raw data, be True only in test + Status Push(uint8_t *buffer, uint32_t buffer_size, uint64_t timestamp, + uint64_t msg_id_start, uint64_t msg_id_end, bool raw = false); + + /// Callback function, will be called when downstream queue notifies + /// it has consumed some items. + /// NOTE: this callback function is called in queue thread. + void OnNotify(std::shared_ptr notify_msg); + + /// Callback function, will be called when downstream queue receives + /// resend items form upstream queue. + /// NOTE: this callback function is called in queue thread. + void OnPull(std::shared_ptr pull_msg, + boost::asio::io_service &service, + std::function)> callback); + + /// Send items through direct call. + void Send(); + + /// Called when user pushs item into queue. The count of items + /// can be evicted, determined by eviction_limit_ and min_consumed_msg_id_. + Status TryEvictItems(); + + void SetQueueEvictionLimit(uint64_t msg_id) { eviction_limit_ = msg_id; } + + uint64_t EvictionLimit() { return eviction_limit_; } + + uint64_t GetMinConsumedMsgID() { return min_consumed_msg_id_; } + + void SetPeerLastIds(uint64_t msg_id, uint64_t seq_id) { + peer_last_msg_id_ = msg_id; + peer_last_seq_id_ = seq_id; + } + + uint64_t GetPeerLastMsgId() { return peer_last_msg_id_; } + + uint64_t GetPeerLastSeqId() { return peer_last_seq_id_; } + + private: + /// Resend an item to peer. + /// \param item, the item object reference to ben resend. + /// \param first_seq_id, the seq id of the first item in this resend sequence. + /// \param last_seq_id, the seq id of the last item in this resend sequence. 
+ void ResendItem(QueueItem &item, uint64_t first_seq_id, uint64_t last_seq_id); + /// Resend items to peer from start_iter iterator to watershed_iter_. + /// \param start_iter, the starting list iterator. + /// \param first_seq_id, the seq id of the first item in this resend sequence. + /// \param last_seq_id, the seq id of the last item in this resend sequence. + int ResendItems(std::list::iterator start_iter, uint64_t first_seq_id, + uint64_t last_seq_id); + /// Find the item which the message with `target_msg_id` in. If the `target_msg_id` + /// is larger than the largest message id in the queue, the `greater_callback` callback + /// will be called; If the `target_message_id` is smaller than the smallest message id + /// in the queue, the `less_callback` callback will be called; If the `target_msg_id` is + /// found in the queue, the `found_callback` callback willbe called. + /// \param target_msg_id, the target message id to be found. + void FindItem(uint64_t target_msg_id, std::function greater_callback, + std::function less_callback, + std::function::iterator, uint64_t, uint64_t)> + equal_callback); + + private: + ActorID actor_id_; + ActorID peer_actor_id_; + uint64_t seq_id_; + uint64_t eviction_limit_; + uint64_t min_consumed_msg_id_; + uint64_t peer_last_msg_id_; + uint64_t peer_last_seq_id_; + std::shared_ptr transport_; + + std::atomic is_resending_; + bool is_upstream_first_pull_; +}; + +/// Queue in downstream. 
+class ReaderQueue : public Queue { + public: + /// \param queue_id, the unique ObjectID to identify a queue + /// \param actor_id, the actor id of upstream worker + /// \param peer_actor_id, the actor id of downstream worker + /// \param transport, transport + /// NOTE: we do not restrict queue size of ReaderQueue + ReaderQueue(const ObjectID &queue_id, const ActorID &actor_id, + const ActorID &peer_actor_id, std::shared_ptr transport) + : Queue(actor_id, peer_actor_id, queue_id, std::numeric_limits::max(), + transport), + actor_id_(actor_id), + peer_actor_id_(peer_actor_id), + last_recv_seq_id_(QUEUE_INVALID_SEQ_ID), + last_recv_msg_id_(QUEUE_INVALID_SEQ_ID), + transport_(transport) {} + + /// Delete processed items whose seq id <= seq_id, + /// then notify upstream queue. + void OnConsumed(uint64_t seq_id); + + void OnData(QueueItem &item); + /// Callback function, will be called when PullPeer DATA comes. + /// TODO: can be combined with OnData + /// NOTE: this callback function is called in queue thread. 
+ void OnResendData(std::shared_ptr msg); + + inline uint64_t GetLastRecvSeqId() { return last_recv_seq_id_; } + inline uint64_t GetLastRecvMsgId() { return last_recv_msg_id_; } + + private: + void Notify(uint64_t seq_id); + void CreateNotifyTask(uint64_t seq_id, std::vector &task_args); + + private: + ActorID actor_id_; + ActorID peer_actor_id_; + uint64_t last_recv_seq_id_; + uint64_t last_recv_msg_id_; + std::shared_ptr promise_for_pull_; + std::shared_ptr transport_; +}; + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/queue/queue_client.cc b/streaming/src/queue/queue_client.cc new file mode 100644 index 00000000..df62e1ca --- /dev/null +++ b/streaming/src/queue/queue_client.cc @@ -0,0 +1,25 @@ +#include "queue/queue_client.h" + +namespace ray { +namespace streaming { + +void WriterClient::OnWriterMessage(std::shared_ptr buffer) { + upstream_handler_->DispatchMessageAsync(buffer); +} + +std::shared_ptr WriterClient::OnWriterMessageSync( + std::shared_ptr buffer) { + return upstream_handler_->DispatchMessageSync(buffer); +} + +void ReaderClient::OnReaderMessage(std::shared_ptr buffer) { + downstream_handler_->DispatchMessageAsync(buffer); +} + +std::shared_ptr ReaderClient::OnReaderMessageSync( + std::shared_ptr buffer) { + return downstream_handler_->DispatchMessageSync(buffer); +} + +} // namespace streaming +} // namespace ray \ No newline at end of file diff --git a/streaming/src/queue/queue_client.h b/streaming/src/queue/queue_client.h new file mode 100644 index 00000000..fec23d66 --- /dev/null +++ b/streaming/src/queue/queue_client.h @@ -0,0 +1,52 @@ +#pragma once + +#include "queue/queue_handler.h" +#include "queue/transport.h" + +namespace ray { +namespace streaming { + +/// The interface of the streaming queue for DataReader. +/// A ReaderClient should be created before DataReader created in Cython/Jni, and hold by +/// Jobworker. 
When DataReader receive a buffer from upstream DataWriter (DataReader's +/// raycall function is called), it calls `OnReaderMessage` to pass the buffer to its own +/// downstream queue, or `OnReaderMessageSync` to wait for handle result. +class ReaderClient { + public: + /// Construct a ReaderClient object. + /// \param[in] async_func DataReader's raycall function descriptor to be called by + /// DataWriter, asynchronous semantics \param[in] sync_func DataReader's raycall + /// function descriptor to be called by DataWriter, synchronous semantics + ReaderClient() { + downstream_handler_ = ray::streaming::DownstreamQueueMessageHandler::CreateService( + CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID()); + } + + /// Post buffer to downstream queue service, asynchronously. + void OnReaderMessage(std::shared_ptr buffer); + /// Post buffer to downstream queue service, synchronously. + /// \return handle result. + std::shared_ptr OnReaderMessageSync( + std::shared_ptr buffer); + + private: + std::shared_ptr downstream_handler_; +}; + +/// Interface of streaming queue for DataWriter. Similar to ReaderClient. 
+class WriterClient { + public: + WriterClient() { + upstream_handler_ = ray::streaming::UpstreamQueueMessageHandler::CreateService( + CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID()); + } + + void OnWriterMessage(std::shared_ptr buffer); + std::shared_ptr OnWriterMessageSync( + std::shared_ptr buffer); + + private: + std::shared_ptr upstream_handler_; +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/queue/queue_handler.cc b/streaming/src/queue/queue_handler.cc new file mode 100644 index 00000000..6b59d8e4 --- /dev/null +++ b/streaming/src/queue/queue_handler.cc @@ -0,0 +1,484 @@ +#include "queue/queue_handler.h" + +#include "queue/utils.h" +#include "util/streaming_util.h" + +namespace ray { +namespace streaming { + +constexpr uint64_t COMMON_SYNC_CALL_TIMEOUTT_MS = 5 * 1000; + +std::shared_ptr + UpstreamQueueMessageHandler::upstream_handler_ = nullptr; +std::shared_ptr + DownstreamQueueMessageHandler::downstream_handler_ = nullptr; + +std::shared_ptr QueueMessageHandler::ParseMessage( + std::shared_ptr buffer) { + uint8_t *bytes = buffer->Data(); + uint8_t *p_cur = bytes; + uint32_t *magic_num = (uint32_t *)p_cur; + STREAMING_CHECK(*magic_num == Message::MagicNum) + << *magic_num << " " << Message::MagicNum; + + p_cur += sizeof(Message::MagicNum); + queue::protobuf::StreamingQueueMessageType *type = + (queue::protobuf::StreamingQueueMessageType *)p_cur; + + std::shared_ptr message = nullptr; + switch (*type) { + case queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType: + message = NotificationMessage::FromBytes(bytes); + break; + case queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType: + message = DataMessage::FromBytes(bytes); + break; + case queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType: + message = CheckMessage::FromBytes(bytes); + break; + case queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType: + message = 
CheckRspMessage::FromBytes(bytes); + break; + case queue::protobuf::StreamingQueuePullRequestMsgType: + message = PullRequestMessage::FromBytes(bytes); + break; + case queue::protobuf::StreamingQueuePullResponseMsgType: + message = PullResponseMessage::FromBytes(bytes); + break; + case queue::protobuf::StreamingQueueResendDataMsgType: + message = ResendDataMessage::FromBytes(bytes); + break; + default: + STREAMING_CHECK(false) << "nonsupport message type: " + << queue::protobuf::StreamingQueueMessageType_Name(*type); + break; + } + + return message; +} + +void QueueMessageHandler::DispatchMessageAsync( + std::shared_ptr buffer) { + queue_service_.post( + boost::bind(&QueueMessageHandler::DispatchMessageInternal, this, buffer, nullptr)); +} + +std::shared_ptr QueueMessageHandler::DispatchMessageSync( + std::shared_ptr buffer) { + std::shared_ptr result = nullptr; + std::shared_ptr promise = std::make_shared(); + queue_service_.post( + boost::bind(&QueueMessageHandler::DispatchMessageInternal, this, buffer, + [&promise, &result](std::shared_ptr rst) { + result = rst; + promise->Notify(ray::Status::OK()); + })); + Status st = promise->Wait(); + STREAMING_CHECK(st.ok()); + + return result; +} + +std::shared_ptr QueueMessageHandler::GetOutTransport( + const ObjectID &queue_id) { + auto it = out_transports_.find(queue_id); + if (it == out_transports_.end()) return nullptr; + + return it->second; +} + +void QueueMessageHandler::SetPeerActorID(const ObjectID &queue_id, + const ActorID &actor_id, RayFunction &async_func, + RayFunction &sync_func) { + actors_.emplace(queue_id, actor_id); + out_transports_.emplace(queue_id, std::make_shared( + actor_id, async_func, sync_func)); +} + +ActorID QueueMessageHandler::GetPeerActorID(const ObjectID &queue_id) { + auto it = actors_.find(queue_id); + STREAMING_CHECK(it != actors_.end()); + return it->second; +} + +void QueueMessageHandler::Release() { + actors_.clear(); + out_transports_.clear(); +} + +void 
QueueMessageHandler::Start() { + queue_thread_ = std::thread(&QueueMessageHandler::QueueThreadCallback, this); +} + +void QueueMessageHandler::Stop() { + STREAMING_LOG(INFO) << "QueueMessageHandler Stop."; + queue_service_.stop(); + if (queue_thread_.joinable()) { + queue_thread_.join(); + } +} + +void UpstreamQueueMessageHandler::Start() { + STREAMING_LOG(INFO) << "UpstreamQueueMessageHandler::Start"; + QueueMessageHandler::Start(); + handle_service_thread_ = std::thread([this] { handler_service_.run(); }); +} + +void UpstreamQueueMessageHandler::Stop() { + STREAMING_LOG(INFO) << "UpstreamQueueMessageHandler::Stop"; + handler_service_.stop(); + if (handle_service_thread_.joinable()) { + handle_service_thread_.join(); + } + QueueMessageHandler::Stop(); +} + +std::shared_ptr UpstreamQueueMessageHandler::CreateService( + const ActorID &actor_id) { + if (nullptr == upstream_handler_) { + upstream_handler_ = std::make_shared(actor_id); + } + return upstream_handler_; +} + +std::shared_ptr UpstreamQueueMessageHandler::GetService() { + return upstream_handler_; +} + +std::shared_ptr UpstreamQueueMessageHandler::CreateUpstreamQueue( + const ObjectID &queue_id, const ActorID &peer_actor_id, uint64_t size) { + STREAMING_LOG(INFO) << "CreateUpstreamQueue: " << queue_id << " " << actor_id_ << "->" + << peer_actor_id; + std::shared_ptr queue = GetUpQueue(queue_id); + if (queue != nullptr) { + STREAMING_LOG(WARNING) << "Duplicate to create up queue." 
<< queue_id; + return queue; + } + + queue = std::make_unique(queue_id, actor_id_, peer_actor_id, + size, GetOutTransport(queue_id)); + upstream_queues_[queue_id] = queue; + + return queue; +} + +bool UpstreamQueueMessageHandler::UpstreamQueueExists(const ObjectID &queue_id) { + return nullptr != GetUpQueue(queue_id); +} + +std::shared_ptr UpstreamQueueMessageHandler::GetUpQueue( + const ObjectID &queue_id) { + auto it = upstream_queues_.find(queue_id); + if (it == upstream_queues_.end()) return nullptr; + + return it->second; +} + +bool UpstreamQueueMessageHandler::CheckQueueSync(const ObjectID &queue_id) { + ActorID peer_actor_id = GetPeerActorID(queue_id); + STREAMING_LOG(INFO) << "CheckQueueSync queue_id: " << queue_id + << " peer_actor_id: " << peer_actor_id; + + CheckMessage msg(actor_id_, peer_actor_id, queue_id); + std::unique_ptr buffer = msg.ToBytes(); + + auto transport_it = GetOutTransport(queue_id); + STREAMING_CHECK(transport_it != nullptr); + std::shared_ptr result_buffer = transport_it->SendForResultWithRetry( + std::move(buffer), 10, COMMON_SYNC_CALL_TIMEOUTT_MS); + if (result_buffer == nullptr) { + return false; + } + + std::shared_ptr result_msg = ParseMessage(result_buffer); + STREAMING_CHECK( + result_msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType); + std::shared_ptr check_rsp_msg = + std::dynamic_pointer_cast(result_msg); + STREAMING_LOG(INFO) << "CheckQueueSync return queue_id: " << check_rsp_msg->QueueId(); + STREAMING_CHECK(check_rsp_msg->PeerActorId() == actor_id_); + + return queue::protobuf::StreamingQueueError::OK == check_rsp_msg->Error(); +} + +void UpstreamQueueMessageHandler::WaitQueues(const std::vector &queue_ids, + int64_t timeout_ms, + std::vector &failed_queues) { + failed_queues.insert(failed_queues.begin(), queue_ids.begin(), queue_ids.end()); + uint64_t start_time_us = current_time_ms(); + uint64_t current_time_us = start_time_us; + while (!failed_queues.empty() && current_time_us 
< start_time_us + timeout_ms * 1000) { + for (auto it = failed_queues.begin(); it != failed_queues.end();) { + if (CheckQueueSync(*it)) { + STREAMING_LOG(INFO) << "Check queue: " << *it << " return, ready."; + it = failed_queues.erase(it); + } else { + STREAMING_LOG(INFO) << "Check queue: " << *it << " return, not ready."; + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + it++; + } + } + current_time_us = current_time_ms(); + } +} + +void UpstreamQueueMessageHandler::DispatchMessageInternal( + std::shared_ptr buffer, + std::function)> callback) { + std::shared_ptr msg = ParseMessage(buffer); + STREAMING_LOG(DEBUG) << "UpstreamQueueMessageHandler::DispatchMessageInternal: " + << " qid: " << msg->QueueId() << " actorid " << msg->ActorId() + << " peer actorid: " << msg->PeerActorId() << " type: " + << queue::protobuf::StreamingQueueMessageType_Name(msg->Type()); + + if (msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType) { + OnNotify(std::dynamic_pointer_cast(msg)); + } else if (msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType) { + STREAMING_CHECK(false) << "Should not receive StreamingQueueCheckRspMsg"; + } else if (msg->Type() == queue::protobuf::StreamingQueueMessageType:: + StreamingQueuePullRequestMsgType) { + STREAMING_CHECK(callback) << "StreamingQueuePullRequestMsg " + << " qid: " << msg->QueueId() << " actorid " + << msg->ActorId() + << " peer actorid: " << msg->PeerActorId(); + OnPullRequest(std::dynamic_pointer_cast(msg), callback); + } else { + STREAMING_CHECK(false) << "message type should be added: " + << queue::protobuf::StreamingQueueMessageType_Name( + msg->Type()); + } +} + +void UpstreamQueueMessageHandler::OnNotify( + std::shared_ptr notify_msg) { + auto queue = GetUpQueue(notify_msg->QueueId()); + if (queue == nullptr) { + STREAMING_LOG(WARNING) << "Can not find queue for " + << queue::protobuf::StreamingQueueMessageType_Name( + notify_msg->Type()) 
+ << ", maybe queue has been destroyed, ignore it." + << " msg id: " << notify_msg->MsgId(); + return; + } + queue->OnNotify(notify_msg); +} + +void UpstreamQueueMessageHandler::OnPullRequest( + std::shared_ptr pull_msg, + std::function)> callback) { + STREAMING_LOG(INFO) << "OnPullRequest"; + auto queue = upstream_queues_.find(pull_msg->QueueId()); + if (queue == upstream_queues_.end()) { + STREAMING_LOG(INFO) << "Can not find queue " << pull_msg->QueueId(); + PullResponseMessage msg(pull_msg->PeerActorId(), pull_msg->ActorId(), + pull_msg->QueueId(), QUEUE_INVALID_SEQ_ID, + QUEUE_INVALID_SEQ_ID, + queue::protobuf::StreamingQueueError::QUEUE_NOT_EXIST, false); + std::unique_ptr buffer = msg.ToBytes(); + callback(std::move(buffer)); + return; + } + + queue->second->OnPull(pull_msg, handler_service_, callback); +} + +void UpstreamQueueMessageHandler::ReleaseAllUpQueues() { + STREAMING_LOG(INFO) << "ReleaseAllUpQueues"; + upstream_queues_.clear(); + Release(); +} + +std::shared_ptr +DownstreamQueueMessageHandler::CreateService(const ActorID &actor_id) { + if (nullptr == downstream_handler_) { + STREAMING_LOG(INFO) << "DownstreamQueueMessageHandler::CreateService " + << " actorid: " << actor_id; + downstream_handler_ = std::make_shared(actor_id); + } + return downstream_handler_; +} + +std::shared_ptr +DownstreamQueueMessageHandler::GetService() { + return downstream_handler_; +} + +bool DownstreamQueueMessageHandler::DownstreamQueueExists(const ObjectID &queue_id) { + return nullptr != GetDownQueue(queue_id); +} + +std::shared_ptr DownstreamQueueMessageHandler::CreateDownstreamQueue( + const ObjectID &queue_id, const ActorID &peer_actor_id) { + STREAMING_LOG(INFO) << "CreateDownstreamQueue: " << queue_id << " " << peer_actor_id + << "->" << actor_id_; + auto it = downstream_queues_.find(queue_id); + if (it != downstream_queues_.end()) { + STREAMING_LOG(WARNING) << "Duplicate to create down queue!!!! 
" << queue_id; + return it->second; + } + + std::shared_ptr queue = + std::make_unique(queue_id, actor_id_, peer_actor_id, + GetOutTransport(queue_id)); + downstream_queues_[queue_id] = queue; + return queue; +} + +StreamingQueueStatus DownstreamQueueMessageHandler::PullQueue( + const ObjectID &queue_id, uint64_t start_msg_id, bool &is_upstream_first_pull, + uint64_t timeout_ms) { + STREAMING_LOG(INFO) << "PullQueue queue_id: " << queue_id + << " start_msg_id: " << start_msg_id + << " is_upstream_first_pull: " << is_upstream_first_pull; + uint64_t start_time = current_time_ms(); + uint64_t current_time = start_time; + StreamingQueueStatus st = StreamingQueueStatus::OK; + while (current_time < start_time + timeout_ms && + (st = PullPeerAsync(queue_id, start_msg_id, is_upstream_first_pull, + timeout_ms)) == StreamingQueueStatus::Timeout) { + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + current_time = current_time_ms(); + } + return st; +} + +std::shared_ptr DownstreamQueueMessageHandler::GetDownQueue( + const ObjectID &queue_id) { + auto it = downstream_queues_.find(queue_id); + if (it == downstream_queues_.end()) return nullptr; + + return it->second; +} + +std::shared_ptr DownstreamQueueMessageHandler::OnCheckQueue( + std::shared_ptr check_msg) { + queue::protobuf::StreamingQueueError err_code = + queue::protobuf::StreamingQueueError::OK; + + auto down_queue = downstream_queues_.find(check_msg->QueueId()); + if (down_queue == downstream_queues_.end()) { + STREAMING_LOG(WARNING) << "OnCheckQueue " << check_msg->QueueId() << " not found."; + err_code = queue::protobuf::StreamingQueueError::QUEUE_NOT_EXIST; + } + + CheckRspMessage msg(check_msg->PeerActorId(), check_msg->ActorId(), + check_msg->QueueId(), err_code); + std::shared_ptr buffer = msg.ToBytes(); + + return buffer; +} + +void DownstreamQueueMessageHandler::ReleaseAllDownQueues() { + STREAMING_LOG(INFO) << "ReleaseAllDownQueues size: " << downstream_queues_.size(); + 
downstream_queues_.clear(); + Release(); +} + +void DownstreamQueueMessageHandler::DispatchMessageInternal( + std::shared_ptr buffer, + std::function)> callback) { + std::shared_ptr msg = ParseMessage(buffer); + STREAMING_LOG(DEBUG) << "DownstreamQueueMessageHandler::DispatchMessageInternal: " + << " qid: " << msg->QueueId() << " actorid " << msg->ActorId() + << " peer actorid: " << msg->PeerActorId() << " type: " + << queue::protobuf::StreamingQueueMessageType_Name(msg->Type()); + + if (msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType) { + OnData(std::dynamic_pointer_cast(msg)); + } else if (msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType) { + std::shared_ptr check_result = + this->OnCheckQueue(std::dynamic_pointer_cast(msg)); + if (callback != nullptr) { + callback(check_result); + } + } else if (msg->Type() == queue::protobuf::StreamingQueueMessageType:: + StreamingQueueResendDataMsgType) { + auto queue = downstream_queues_.find(msg->QueueId()); + if (queue == downstream_queues_.end()) { + std::shared_ptr data_msg = + std::dynamic_pointer_cast(msg); + STREAMING_LOG(DEBUG) << "Can not find queue for " + << queue::protobuf::StreamingQueueMessageType_Name(msg->Type()) + << ", maybe queue has been destroyed, ignore it." + << " seq id: " << data_msg->SeqId(); + return; + } + std::shared_ptr resend_data_msg = + std::dynamic_pointer_cast(msg); + + queue->second->OnResendData(resend_data_msg); + } else { + STREAMING_CHECK(false) << "message type should be added: " + << queue::protobuf::StreamingQueueMessageType_Name( + msg->Type()); + } +} + +void DownstreamQueueMessageHandler::OnData(std::shared_ptr msg) { + auto queue = GetDownQueue(msg->QueueId()); + if (queue == nullptr) { + STREAMING_LOG(WARNING) << "Can not find queue for " + << queue::protobuf::StreamingQueueMessageType_Name(msg->Type()) + << ", maybe queue has been destroyed, ignore it." 
+ << " seq id: " << msg->SeqId(); + return; + } + + QueueItem item(msg); + queue->OnData(item); +} + +StreamingQueueStatus DownstreamQueueMessageHandler::PullPeerAsync( + const ObjectID &queue_id, uint64_t start_msg_id, bool &is_upstream_first_pull, + uint64_t timeout_ms) { + STREAMING_LOG(INFO) << "PullPeerAsync queue_id: " << queue_id + << " start_msg_id: " << start_msg_id; + auto queue = GetDownQueue(queue_id); + STREAMING_CHECK(queue != nullptr); + STREAMING_LOG(INFO) << "PullPeerAsync " + << " actorid: " << queue->GetActorID(); + PullRequestMessage msg(queue->GetActorID(), queue->GetPeerActorID(), queue_id, + start_msg_id); + std::unique_ptr buffer = msg.ToBytes(); + + auto transport_it = GetOutTransport(queue_id); + STREAMING_CHECK(transport_it != nullptr); + std::shared_ptr result_buffer = + transport_it->SendForResultWithRetry(std::move(buffer), 1, timeout_ms); + if (result_buffer == nullptr) { + return StreamingQueueStatus::Timeout; + } + + std::shared_ptr result_msg = ParseMessage(result_buffer); + STREAMING_CHECK( + result_msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueuePullResponseMsgType); + std::shared_ptr response_msg = + std::dynamic_pointer_cast(result_msg); + + STREAMING_LOG(INFO) << "PullPeerAsync error: " + << queue::protobuf::StreamingQueueError_Name(response_msg->Error()) + << " start_msg_id: " << start_msg_id; + + is_upstream_first_pull = response_msg->IsUpstreamFirstPull(); + if (response_msg->Error() == queue::protobuf::StreamingQueueError::OK) { + STREAMING_LOG(INFO) << "Set queue " << queue_id << " expect_seq_id to " + << response_msg->SeqId(); + return StreamingQueueStatus::OK; + } else if (response_msg->Error() == queue::protobuf::StreamingQueueError::DATA_LOST) { + return StreamingQueueStatus::DataLost; + } else if (response_msg->Error() == + queue::protobuf::StreamingQueueError::NO_VALID_DATA) { + return StreamingQueueStatus::NoValidData; + } else { // QUEUE_NOT_EXIST + return StreamingQueueStatus::Timeout; 
+ } +} + +} // namespace streaming +} // namespace ray \ No newline at end of file diff --git a/streaming/src/queue/queue_handler.h b/streaming/src/queue/queue_handler.h new file mode 100644 index 00000000..26bd863e --- /dev/null +++ b/streaming/src/queue/queue_handler.h @@ -0,0 +1,217 @@ +#pragma once + +#include +#include +#include +#include + +#include "queue/queue.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +using namespace ray::core; + +enum class StreamingQueueStatus : uint32_t { + OK = 0, + Timeout = 1, + DataLost = 2, // The data in upstream has been evicted when downstream try to pull data + // from upstream. + NoValidData = 3, // There is no data written into queue, or start_msg_id is bigger than + // all items in queue now. +}; + +static inline std::ostream &operator<<(std::ostream &os, + const StreamingQueueStatus &status) { + os << static_cast::type>(status); + return os; +} + +/// Base class of UpstreamQueueMessageHandler and DownstreamQueueMessageHandler. +/// A queue service manages a group of queues, upstream queues or downstream queues of +/// the current actor. Each queue service holds a boost.asio io_service, to handle +/// messages asynchronously. When a message received by Writer/Reader in ray call thread, +/// the message was delivered to +/// UpstreamQueueMessageHandler/DownstreamQueueMessageHandler, then the ray call thread +/// returns immediately. The queue service parses meta infomation from the message, +/// including queue_id actor_id, etc, and dispatchs message to queue according to +/// queue_id. +class QueueMessageHandler { + public: + /// Construct a QueueMessageHandler instance. + /// \param[in] actor_id actor id of current actor. + QueueMessageHandler(const ActorID &actor_id) + : actor_id_(actor_id), queue_dummy_work_(queue_service_) {} + + virtual ~QueueMessageHandler() { Stop(); } + + /// Dispatch message buffer to asio service. 
+ /// \param[in] buffer serialized message received from peer actor. + void DispatchMessageAsync(std::shared_ptr buffer); + + /// Dispatch message buffer to asio service synchronously, and wait for handle result. + /// \param[in] buffer serialized message received from peer actor. + /// \return handle result. + std::shared_ptr DispatchMessageSync( + std::shared_ptr buffer); + + /// Get transport to a peer actor specified by actor_id. + /// \param[in] actor_id actor id of peer actor + /// \return transport + std::shared_ptr GetOutTransport(const ObjectID &actor_id); + + /// The actual function where message being dispatched, called by DispatchMessageAsync + /// and DispatchMessageSync. + /// \param[in] buffer serialized message received from peer actor. + /// \param[in] callback the callback function used by DispatchMessageSync, called + /// after message processed complete. The std::shared_ptr + /// parameter is the return value. + virtual void DispatchMessageInternal( + std::shared_ptr buffer, + std::function)> callback) = 0; + + /// Save actor_id of the peer actor specified by queue_id. For a upstream queue, the + /// peer actor refer specifically to the actor in current ray cluster who has a + /// downstream queue with same queue_id, and vice versa. + /// \param[in] queue_id queue id of current queue. + /// \param[in] actor_id actor_id actor id of corresponded peer actor. + void SetPeerActorID(const ObjectID &queue_id, const ActorID &actor_id, + RayFunction &async_func, RayFunction &sync_func); + + /// Obtain the actor id of the peer actor specified by queue_id. + /// \return actor id + ActorID GetPeerActorID(const ObjectID &queue_id); + + /// Release all queues in current queue service. + void Release(); + + protected: + /// Start asio service + virtual void Start(); + /// Stop asio service + virtual void Stop(); + /// The callback function of internal thread. 
+ void QueueThreadCallback() { queue_service_.run(); } + + protected: + /// actor_id actor id of current actor + ActorID actor_id_; + /// Helper function, parse message buffer to Message object. + std::shared_ptr ParseMessage(std::shared_ptr buffer); + + private: + /// Map from queue id to a actor id of the queue's peer actor. + std::unordered_map actors_; + /// Map from queue id to a transport of the queue's peer actor. + std::unordered_map> out_transports_; + /// The internal thread which asio service run with. + std::thread queue_thread_; + /// The internal asio service. + boost::asio::io_service queue_service_; + /// The asio work which keeps queue_service_ alive. + boost::asio::io_service::work queue_dummy_work_; +}; + +/// UpstreamQueueMessageHandler holds and manages all upstream queues of current actor. +class UpstreamQueueMessageHandler : public QueueMessageHandler { + public: + /// Construct a UpstreamQueueMessageHandler instance. + UpstreamQueueMessageHandler(const ActorID &actor_id) + : QueueMessageHandler(actor_id), handler_service_dummy_worker_(handler_service_) { + Start(); + } + /// Create a upstream queue. + /// \param[in] queue_id queue id of the queue to be created. + /// \param[in] peer_actor_id actor id of peer actor. + /// \param[in] size the max memory size of the queue. + std::shared_ptr CreateUpstreamQueue(const ObjectID &queue_id, + const ActorID &peer_actor_id, + uint64_t size); + /// Check whether the upstream queue specified by queue_id exists or not. + bool UpstreamQueueExists(const ObjectID &queue_id); + /// Wait all queues in queue_ids vector ready, until timeout. + /// \param[in] queue_ids a group of queues. + /// \param[in] timeout_ms max timeout time interval for wait all queues. + /// \param[out] failed_queues a group of queues which are not ready when timeout. + void WaitQueues(const std::vector &queue_ids, int64_t timeout_ms, + std::vector &failed_queues); + /// Handle notify message from corresponded downstream queue. 
+ void OnNotify(std::shared_ptr notify_msg); + /// Handle pull request message from corresponded downstream queue. + void OnPullRequest(std::shared_ptr pull_msg, + std::function)> callback); + /// Obtain upstream queue specified by queue_id. + std::shared_ptr GetUpQueue(const ObjectID &queue_id); + /// Release all upstream queues + void ReleaseAllUpQueues(); + + virtual void DispatchMessageInternal( + std::shared_ptr buffer, + std::function)> callback) override; + + static std::shared_ptr CreateService( + const ActorID &actor_id); + static std::shared_ptr GetService(); + virtual void Start() override; + + private: + bool CheckQueueSync(const ObjectID &queue_ids); + virtual void Stop() override; + + private: + std::unordered_map> upstream_queues_; + static std::shared_ptr upstream_handler_; + boost::asio::io_service handler_service_; + boost::asio::io_service::work handler_service_dummy_worker_; + std::thread handle_service_thread_; +}; + +/// DownstreamQueueMessageHandler holds and manages all downstream queues of current +/// actor. +class DownstreamQueueMessageHandler : public QueueMessageHandler { + public: + DownstreamQueueMessageHandler(const ActorID &actor_id) : QueueMessageHandler(actor_id) { + Start(); + } + /// Create a downstream queue. + /// \param queue_id, queue id of the queue to be created. + /// \param peer_actor_id, actor id of peer actor. + std::shared_ptr CreateDownstreamQueue(const ObjectID &queue_id, + const ActorID &peer_actor_id); + /// Request to pull messages from corresponded upstream queue, whose message id + /// is larger than `start_msg_id`. Multiple attempts to pull until timeout. + /// \param queue_id, queue id of the queue to be pulled. + /// \param start_msg_id, the starting message id reqeust by downstream queue. + /// \param is_upstream_first_pull + /// \param timeout_ms, the maxmium timeout. 
+ StreamingQueueStatus PullQueue(const ObjectID &queue_id, uint64_t start_msg_id, + bool &is_upstream_first_pull, + uint64_t timeout_ms = 2000); + /// Check whether the downstream queue specified by queue_id exists or not. + bool DownstreamQueueExists(const ObjectID &queue_id); + std::shared_ptr OnCheckQueue( + std::shared_ptr check_msg); + /// Obtain downstream queue specified by queue_id. + std::shared_ptr GetDownQueue(const ObjectID &queue_id); + /// Release all downstream queues + void ReleaseAllDownQueues(); + /// The callback function called when downstream queue receives a queue item. + void OnData(std::shared_ptr msg); + virtual void DispatchMessageInternal( + std::shared_ptr buffer, + std::function)> callback); + static std::shared_ptr CreateService( + const ActorID &actor_id); + static std::shared_ptr GetService(); + StreamingQueueStatus PullPeerAsync(const ObjectID &queue_id, uint64_t start_msg_id, + bool &is_upstream_first_pull, uint64_t timeout_ms); + + private: + std::unordered_map> + downstream_queues_; + static std::shared_ptr downstream_handler_; +}; + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/queue/queue_item.h b/streaming/src/queue/queue_item.h new file mode 100644 index 00000000..b63e0eb7 --- /dev/null +++ b/streaming/src/queue/queue_item.h @@ -0,0 +1,134 @@ +#pragma once + +#include +#include +#include +#include + +#include "message/message_bundle.h" +#include "queue/message.h" +#include "ray/common/id.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +using ray::ObjectID; +const uint64_t QUEUE_INVALID_SEQ_ID = std::numeric_limits::max(); +const uint64_t QUEUE_INITIAL_SEQ_ID = 1; + +/// QueueItem is the element stored in `Queue`. Actually, when DataWriter pushes a message +/// bundle into a queue, the bundle is packed into one QueueItem, so a one-to-one +/// relationship exists between message bundle and QueueItem. 
Meanwhile, the QueueItem is +/// also the minimum unit to send through direct actor call. Each QueueItem holds a +/// LocalMemoryBuffer shared_ptr, which will be sent out by Transport. +class QueueItem { + public: + QueueItem() = default; + /// Construct a QueueItem object. + /// \param[in] seq_id the sequential id assigned by DataWriter for a message bundle and + /// QueueItem. + /// \param[in] data the data buffer to be stored in this QueueItem. + /// \param[in] data_size the data size in bytes. + /// \param[in] timestamp the time when this QueueItem created. + /// \param[in] raw whether the data content is raw bytes, only used in some tests. + QueueItem(uint64_t seq_id, uint8_t *data, uint32_t data_size, uint64_t timestamp, + uint64_t msg_id_start, uint64_t msg_id_end, bool raw = false) + : seq_id_(seq_id), + msg_id_start_(msg_id_start), + msg_id_end_(msg_id_end), + timestamp_(timestamp), + raw_(raw), + /*COPY*/ buffer_(std::make_shared(data, data_size, true)) {} + + QueueItem(uint64_t seq_id, std::shared_ptr buffer, + uint64_t timestamp, uint64_t msg_id_start, uint64_t msg_id_end, + bool raw = false) + : seq_id_(seq_id), + msg_id_start_(msg_id_start), + msg_id_end_(msg_id_end), + timestamp_(timestamp), + raw_(raw), + buffer_(buffer) {} + + QueueItem(std::shared_ptr data_msg) + : seq_id_(data_msg->SeqId()), + msg_id_start_(data_msg->MsgIdStart()), + msg_id_end_(data_msg->MsgIdEnd()), + raw_(data_msg->IsRaw()), + buffer_(data_msg->Buffer()) {} + + QueueItem(const QueueItem &&item) { + buffer_ = item.buffer_; + seq_id_ = item.seq_id_; + msg_id_start_ = item.msg_id_start_; + msg_id_end_ = item.msg_id_end_; + timestamp_ = item.timestamp_; + raw_ = item.raw_; + } + + QueueItem(const QueueItem &item) { + buffer_ = item.buffer_; + seq_id_ = item.seq_id_; + msg_id_start_ = item.msg_id_start_; + msg_id_end_ = item.msg_id_end_; + timestamp_ = item.timestamp_; + raw_ = item.raw_; + } + + QueueItem &operator=(const QueueItem &item) { + buffer_ = item.buffer_; + seq_id_ = 
item.seq_id_; + msg_id_start_ = item.msg_id_start_; + msg_id_end_ = item.msg_id_end_; + timestamp_ = item.timestamp_; + raw_ = item.raw_; + return *this; + } + + virtual ~QueueItem() = default; + + inline uint64_t SeqId() { return seq_id_; } + inline uint64_t MsgIdStart() { return msg_id_start_; } + inline uint64_t MsgIdEnd() { return msg_id_end_; } + inline bool InItem(uint64_t msg_id) { + return msg_id >= msg_id_start_ && msg_id <= msg_id_end_; + } + inline bool IsRaw() { return raw_; } + inline uint64_t TimeStamp() { return timestamp_; } + inline size_t DataSize() { return buffer_->Size(); } + inline std::shared_ptr Buffer() { return buffer_; } + + /// Get max message id in this item. + /// \return max message id. + uint64_t MaxMsgId() { + if (raw_) { + return 0; + } + auto message_bundle = StreamingMessageBundleMeta::FromBytes(buffer_->Data()); + return message_bundle->GetLastMessageId(); + } + + protected: + uint64_t seq_id_; + uint64_t msg_id_start_; + uint64_t msg_id_end_; + uint64_t timestamp_; + bool raw_; + + std::shared_ptr buffer_; +}; + +class InvalidQueueItem : public QueueItem { + public: + InvalidQueueItem() + : QueueItem(QUEUE_INVALID_SEQ_ID, data_, 1, 0, QUEUE_INVALID_SEQ_ID, + QUEUE_INVALID_SEQ_ID) {} + + private: + uint8_t data_[1]; +}; +typedef std::shared_ptr QueueItemPtr; + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/queue/transport.cc b/streaming/src/queue/transport.cc new file mode 100644 index 00000000..bdee9284 --- /dev/null +++ b/streaming/src/queue/transport.cc @@ -0,0 +1,66 @@ +#include "queue/transport.h" + +#include "queue/utils.h" +#include "ray/common/common_protocol.h" +#include "ray/streaming/streaming.h" + +namespace ray { +namespace streaming { + +static constexpr int TASK_OPTION_RETURN_NUM_0 = 0; +static constexpr int TASK_OPTION_RETURN_NUM_1 = 1; + +void Transport::Send(std::shared_ptr buffer) { + STREAMING_LOG(DEBUG) << "Transport::Send buffer size: " << buffer->Size(); + 
RAY_UNUSED(ray::streaming::SendInternal(peer_actor_id_, std::move(buffer), async_func_, + TASK_OPTION_RETURN_NUM_0)); +} + +std::shared_ptr Transport::SendForResult( + std::shared_ptr buffer, int64_t timeout_ms) { + auto return_refs = ray::streaming::SendInternal(peer_actor_id_, buffer, sync_func_, + TASK_OPTION_RETURN_NUM_1); + auto return_ids = ObjectRefsToIds(return_refs); + + std::vector> results; + Status get_st = + CoreWorkerProcess::GetCoreWorker().Get(return_ids, timeout_ms, &results); + if (!get_st.ok()) { + STREAMING_LOG(ERROR) << "Get fail."; + return nullptr; + } + STREAMING_CHECK(results.size() >= 1); + if (results[0]->IsException()) { + STREAMING_LOG(ERROR) << "peer actor may has exceptions, should retry."; + return nullptr; + } + STREAMING_CHECK(results[0]->HasData()); + if (results[0]->GetData()->Size() == 4) { + STREAMING_LOG(WARNING) << "peer actor may not ready yet, should retry."; + return nullptr; + } + + std::shared_ptr result_buffer = results[0]->GetData(); + std::shared_ptr return_buffer = std::make_shared( + result_buffer->Data(), result_buffer->Size(), true); + return return_buffer; +} + +std::shared_ptr Transport::SendForResultWithRetry( + std::shared_ptr buffer, int retry_cnt, int64_t timeout_ms) { + STREAMING_LOG(INFO) << "SendForResultWithRetry retry_cnt: " << retry_cnt + << " timeout_ms: " << timeout_ms; + std::shared_ptr buffer_shared = std::move(buffer); + for (int cnt = 0; cnt < retry_cnt; cnt++) { + auto result = SendForResult(buffer_shared, timeout_ms); + if (result != nullptr) { + return result; + } + } + + STREAMING_LOG(WARNING) << "SendForResultWithRetry fail after retry."; + return nullptr; +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/queue/transport.h b/streaming/src/queue/transport.h new file mode 100644 index 00000000..7e3a6bca --- /dev/null +++ b/streaming/src/queue/transport.h @@ -0,0 +1,59 @@ +#pragma once + +#include "ray/common/id.h" +#include "ray/core_worker/core_worker.h" +#include 
"util/streaming_logging.h" + +namespace ray { +namespace streaming { + +using namespace ray::core; + +/// Transport is the transfer endpoint to a specific actor, buffers can be sent to peer +/// through direct actor call. +class Transport { + public: + /// Construct a Transport object. + /// \param[in] peer_actor_id actor id of peer actor. + Transport(const ActorID &peer_actor_id, RayFunction &async_func, RayFunction &sync_func) + : peer_actor_id_(peer_actor_id), async_func_(async_func), sync_func_(sync_func) { + STREAMING_LOG(INFO) << "Transport constructor:"; + STREAMING_LOG(INFO) << "async_func lang: " << async_func_.GetLanguage(); + STREAMING_LOG(INFO) << "async_func: " + << async_func_.GetFunctionDescriptor()->ToString(); + STREAMING_LOG(INFO) << "sync_func lang: " << sync_func_.GetLanguage(); + STREAMING_LOG(INFO) << "sync_func: " + << sync_func_.GetFunctionDescriptor()->ToString(); + } + + virtual ~Transport() = default; + + /// Send buffer asynchronously, peer's `function` will be called. + /// \param[in] buffer buffer to be sent. + virtual void Send(std::shared_ptr buffer); + + /// Send buffer synchronously, peer's `function` will be called, and return the peer + /// function's return value. + /// \param[in] buffer buffer to be sent. + /// \param[in] timeout_ms max time to wait for result. + /// \return peer function's result. + virtual std::shared_ptr SendForResult( + std::shared_ptr buffer, int64_t timeout_ms); + + /// Send buffer and get result with retry. + /// return value. + /// \param[in] buffer buffer to be sent. + /// \param[in] max retry count + /// \param[in] timeout_ms max time to wait for result. + /// \return peer function's result. 
+ std::shared_ptr SendForResultWithRetry( + std::shared_ptr buffer, int retry_cnt, int64_t timeout_ms); + + private: + WorkerID worker_id_; + ActorID peer_actor_id_; + RayFunction async_func_; + RayFunction sync_func_; +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/queue/utils.h b/streaming/src/queue/utils.h new file mode 100644 index 00000000..3838f2f6 --- /dev/null +++ b/streaming/src/queue/utils.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include + +#include "ray/util/util.h" + +namespace ray { +namespace streaming { + +/// Helper class encapulate std::future to help multithread async wait. +class PromiseWrapper { + public: + Status Wait() { + std::future fut = promise_.get_future(); + fut.get(); + return status_; + } + + Status WaitFor(uint64_t timeout_ms) { + std::future fut = promise_.get_future(); + std::future_status status; + do { + status = fut.wait_for(std::chrono::milliseconds(timeout_ms)); + if (status == std::future_status::deferred) { + } else if (status == std::future_status::timeout) { + return Status::Invalid("timeout"); + } else if (status == std::future_status::ready) { + return status_; + } + } while (status == std::future_status::deferred); + + return status_; + } + + void Notify(Status status) { + status_ = status; + promise_.set_value(true); + } + + Status GetResultStatus() { return status_; } + + private: + std::promise promise_; + Status status_; +}; + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/reliability/barrier_helper.cc b/streaming/src/reliability/barrier_helper.cc new file mode 100644 index 00000000..14d66b79 --- /dev/null +++ b/streaming/src/reliability/barrier_helper.cc @@ -0,0 +1,165 @@ +#include "barrier_helper.h" + +#include + +#include "util/streaming_logging.h" +#include "util/streaming_util.h" + +namespace ray { +namespace streaming { +StreamingStatus StreamingBarrierHelper::GetMsgIdByBarrierId(const ObjectID &q_id, + uint64_t barrier_id, + uint64_t 
&msg_id) { + std::lock_guard lock(global_barrier_mutex_); + auto queue_map = global_barrier_map_.find(barrier_id); + if (queue_map == global_barrier_map_.end()) { + return StreamingStatus::NoSuchItem; + } + auto msg_id_map = queue_map->second.find(q_id); + if (msg_id_map == queue_map->second.end()) { + return StreamingStatus::QueueIdNotFound; + } + msg_id = msg_id_map->second; + return StreamingStatus::OK; +} + +void StreamingBarrierHelper::SetMsgIdByBarrierId(const ObjectID &q_id, + uint64_t barrier_id, uint64_t msg_id) { + std::lock_guard lock(global_barrier_mutex_); + global_barrier_map_[barrier_id][q_id] = msg_id; +} + +void StreamingBarrierHelper::ReleaseBarrierMapById(uint64_t barrier_id) { + std::lock_guard lock(global_barrier_mutex_); + global_barrier_map_.erase(barrier_id); +} + +void StreamingBarrierHelper::ReleaseAllBarrierMap() { + std::lock_guard lock(global_barrier_mutex_); + global_barrier_map_.clear(); +} + +void StreamingBarrierHelper::MapBarrierToCheckpoint(uint64_t barrier_id, + uint64_t checkpoint) { + std::lock_guard lock(barrier_map_checkpoint_mutex_); + barrier_checkpoint_map_[barrier_id] = checkpoint; +} + +StreamingStatus StreamingBarrierHelper::GetCheckpointIdByBarrierId( + uint64_t barrier_id, uint64_t &checkpoint_id) { + std::lock_guard lock(barrier_map_checkpoint_mutex_); + auto checkpoint_item = barrier_checkpoint_map_.find(barrier_id); + if (checkpoint_item == barrier_checkpoint_map_.end()) { + return StreamingStatus::NoSuchItem; + } + + checkpoint_id = checkpoint_item->second; + return StreamingStatus::OK; +} + +void StreamingBarrierHelper::ReleaseBarrierMapCheckpointByBarrierId( + const uint64_t barrier_id) { + std::lock_guard lock(barrier_map_checkpoint_mutex_); + auto it = barrier_checkpoint_map_.begin(); + while (it != barrier_checkpoint_map_.end()) { + if (it->first <= barrier_id) { + it = barrier_checkpoint_map_.erase(it); + } else { + it++; + } + } +} + +StreamingStatus StreamingBarrierHelper::GetBarrierIdByLastMessageId(const 
ObjectID &q_id, + uint64_t message_id, + uint64_t &barrier_id, + bool is_pop) { + std::lock_guard lock(message_id_map_barrier_mutex_); + auto message_item = global_reversed_barrier_map_.find(message_id); + if (message_item == global_reversed_barrier_map_.end()) { + return StreamingStatus::NoSuchItem; + } + + auto message_queue_item = message_item->second.find(q_id); + if (message_queue_item == message_item->second.end()) { + return StreamingStatus::QueueIdNotFound; + } + if (message_queue_item->second->empty()) { + STREAMING_LOG(WARNING) << "[Barrier] q id => " << q_id.Hex() << ", str num => " + << Util::Hexqid2str(q_id.Hex()) << ", message id " + << message_id; + return StreamingStatus::NoSuchItem; + } else { + barrier_id = message_queue_item->second->front(); + if (is_pop) { + message_queue_item->second->pop(); + } + } + return StreamingStatus::OK; +} + +void StreamingBarrierHelper::SetBarrierIdByLastMessageId(const ObjectID &q_id, + uint64_t message_id, + uint64_t barrier_id) { + std::lock_guard lock(message_id_map_barrier_mutex_); + + auto max_message_id_barrier = max_message_id_map_.find(q_id); + // remove finished barrier in different last message id + if (max_message_id_barrier != max_message_id_map_.end() && + max_message_id_barrier->second != message_id) { + if (global_reversed_barrier_map_.find(max_message_id_barrier->second) != + global_reversed_barrier_map_.end()) { + global_reversed_barrier_map_.erase(max_message_id_barrier->second); + } + } + + max_message_id_map_[q_id] = message_id; + auto message_item = global_reversed_barrier_map_.find(message_id); + if (message_item == global_reversed_barrier_map_.end()) { + BarrierIdQueue temp_queue = std::make_shared>(); + temp_queue->push(barrier_id); + global_reversed_barrier_map_[message_id][q_id] = temp_queue; + return; + } + auto message_queue_item = message_item->second.find(q_id); + if (message_queue_item != message_item->second.end()) { + message_queue_item->second->push(barrier_id); + } else { + 
BarrierIdQueue temp_queue = std::make_shared>(); + temp_queue->push(barrier_id); + global_reversed_barrier_map_[message_id][q_id] = temp_queue; + } +} + +void StreamingBarrierHelper::GetAllBarrier(std::vector &barrier_id_vec) { + std::transform( + global_barrier_map_.begin(), global_barrier_map_.end(), + std::back_inserter(barrier_id_vec), + [](std::unordered_map>::value_type + pair) { return pair.first; }); +} + +bool StreamingBarrierHelper::Contains(uint64_t barrier_id) { + return global_barrier_map_.find(barrier_id) != global_barrier_map_.end(); +} + +uint32_t StreamingBarrierHelper::GetBarrierMapSize() { + return global_barrier_map_.size(); +} + +void StreamingBarrierHelper::GetCurrentMaxCheckpointIdInQueue( + const ObjectID &q_id, uint64_t &checkpoint_id) const { + auto item = current_max_checkpoint_id_map_.find(q_id); + if (item != current_max_checkpoint_id_map_.end()) { + checkpoint_id = item->second; + } else { + checkpoint_id = 0; + } +} + +void StreamingBarrierHelper::SetCurrentMaxCheckpointIdInQueue( + const ObjectID &q_id, const uint64_t checkpoint_id) { + current_max_checkpoint_id_map_[q_id] = checkpoint_id; +} +} // namespace streaming +} // namespace ray diff --git a/streaming/src/reliability/barrier_helper.h b/streaming/src/reliability/barrier_helper.h new file mode 100644 index 00000000..23f0e11c --- /dev/null +++ b/streaming/src/reliability/barrier_helper.h @@ -0,0 +1,120 @@ +#pragma once +#include +#include + +#include "common/status.h" +#include "ray/common/id.h" + +namespace ray { +namespace streaming { +class StreamingBarrierHelper final { + using BarrierIdQueue = std::shared_ptr>; + + public: + StreamingBarrierHelper() {} + /// No duplicated barrier helper should be loaded in data writer or data + /// reader, so we mark BarrierHelper as a nocopyable object. 
+ StreamingBarrierHelper(const StreamingBarrierHelper &barrier_helper) = delete; + + StreamingBarrierHelper operator=(const StreamingBarrierHelper &barrier_helper) = delete; + + virtual ~StreamingBarrierHelper() = default; + + /// Get barrier id from queue-barrier map by given seq-id. + /// \param_in q_id, channel id + /// \param_in barrier_id, barrier or checkpoint of long runtime job + /// \param_out msg_id, message id of barrier + StreamingStatus GetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, + uint64_t &msg_id); + + /// Append new message id to queue-barrier map. + /// \param_in q_id, channel id + /// \param_in barrier_id, barrier or checkpoint of long running job + /// \param_in msg_id, message id of barrier + void SetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, uint64_t msg_id); + + /// Check whether barrier id in queue-barrier map. + /// \param_in barrier_id, barrier id or checkpoint id + bool Contains(uint64_t barrier_id); + + /// Remove barrier info from queue-barrier map by given seq id. + void ReleaseBarrierMapById(uint64_t barrier_id); + + /// Remove all barrier info from queue-barrier map. + void ReleaseAllBarrierMap(); + + /// Fetch barrier id list from queue-barrier map. + void GetAllBarrier(std::vector &barrier_id_vec); + + /// Get barrier map capacity of current version. + uint32_t GetBarrierMapSize(); + + /// We assume there are multiple barriers in one checkpoint, so barrier id + /// should belong to a checkpoint id. 
+ /// \param_in barrier_id, barrier id + /// \param_in checkpoint_id, checkpoint id + void MapBarrierToCheckpoint(uint64_t barrier_id, uint64_t checkpoint_id); + + /// Get checkpoint id by given barrier id + /// \param_in barrier_id, barrier id + /// \param_out checkpoint_id, checkpoint id + StreamingStatus GetCheckpointIdByBarrierId(uint64_t barrier_id, + uint64_t &checkpoint_id); + + /// Clear barrier-checkpoint relation if elements of barrier id vector are + /// equal to or less than given barrier id. + /// \param_in barrier_id + void ReleaseBarrierMapCheckpointByBarrierId(const uint64_t barrier_id); + + /// Get barrier id by lastest message id and channel + /// \param_in q_id, channel id + /// \param_in message_id, lastest message id of barrier data + /// \param_out barrier_id, barrier id + /// \param_in is_pop, whether pop out from queue + StreamingStatus GetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id, + uint64_t &barrier_id, bool is_pop = false); + + /// Put new barrier id in map by channel index and lastest message id. + /// \param_in q_id, channel id + /// \param_in message_id, lastest message id of barrier data + /// \param_in barrier_id, barrier id + void SetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id, + uint64_t barrier_id); + + /// \param_in q_id, channel id + /// \param_in checkpoint_id, checkpoint id of long running job + void GetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id, + uint64_t &checkpoint_id) const; + + /// \param_in q_id, channel id + /// \param_in checkpoint_id, checkpoint id of long running job + void SetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id, + const uint64_t checkpoint_id); + + private: + // Global barrier map set (global barrier id -> (channel id -> seq id)) + std::unordered_map> + global_barrier_map_; + + // Message id map to barrier id of each queue(continuous barriers hold same last message + // id) + // message id -> (queue id -> list(barrier id)). 
+ // Thread unsafe to assign value in user's thread but collect it in loopforward thread. + std::unordered_map> + global_reversed_barrier_map_; + + std::unordered_map barrier_checkpoint_map_; + + std::unordered_map max_message_id_map_; + + // We assume default max checkpoint is 0. + std::unordered_map current_max_checkpoint_id_map_; + + std::mutex message_id_map_barrier_mutex_; + + std::mutex global_barrier_mutex_; + + std::mutex barrier_map_checkpoint_mutex_; +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/reliability_helper.cc b/streaming/src/reliability_helper.cc new file mode 100644 index 00000000..9e4a083a --- /dev/null +++ b/streaming/src/reliability_helper.cc @@ -0,0 +1,113 @@ +#include "reliability_helper.h" + +#include +namespace ray { +namespace streaming { + +std::shared_ptr ReliabilityHelperFactory::CreateReliabilityHelper( + const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader) { + if (config.IsExactlyOnce()) { + return std::make_shared(config, barrier_helper, writer, reader); + } else { + return std::make_shared(config, barrier_helper, writer, reader); + } +} + +ReliabilityHelper::ReliabilityHelper(const StreamingConfig &config, + StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader) + : config_(config), + barrier_helper_(barrier_helper), + writer_(writer), + reader_(reader) {} + +void ReliabilityHelper::Reload() {} + +bool ReliabilityHelper::StoreBundleMeta(ProducerChannelInfo &channel_info, + StreamingMessageBundlePtr &bundle_ptr, + bool is_replay) { + return false; +} + +bool ReliabilityHelper::FilterMessage(ProducerChannelInfo &channel_info, + const uint8_t *data, + StreamingMessageType message_type, + uint64_t *write_message_id) { + bool is_filtered = false; + uint64_t &message_id = channel_info.current_message_id; + uint64_t last_msg_id = channel_info.message_last_commit_id; + + if (StreamingMessageType::Barrier == message_type) { 
+ is_filtered = message_id < last_msg_id; + } else { + message_id++; + // Message last commit id is the last item in queue or restore from queue. + // It skip directly since message id is less or equal than current commit id. + is_filtered = message_id <= last_msg_id && !config_.IsAtLeastOnce(); + } + *write_message_id = message_id; + + return is_filtered; +} + +void ReliabilityHelper::CleanupCheckpoint(ProducerChannelInfo &channel_info, + uint64_t barrier_id) {} + +StreamingStatus ReliabilityHelper::InitChannelMerger(uint32_t timeout) { + return reader_->InitChannelMerger(timeout); +} + +StreamingStatus ReliabilityHelper::HandleNoValidItem(ConsumerChannelInfo &channel_info) { + STREAMING_LOG(DEBUG) << "[Reader] Queue " << channel_info.channel_id + << " get item timeout, resend notify " + << channel_info.current_message_id; + reader_->NotifyConsumedItem(channel_info, channel_info.current_message_id); + return StreamingStatus::OK; +} + +AtLeastOnceHelper::AtLeastOnceHelper(const StreamingConfig &config, + StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader) + : ReliabilityHelper(config, barrier_helper, writer, reader) {} + +StreamingStatus AtLeastOnceHelper::InitChannelMerger(uint32_t timeout) { + // No merge in AT_LEAST_ONCE + return StreamingStatus::OK; +} + +StreamingStatus AtLeastOnceHelper::HandleNoValidItem(ConsumerChannelInfo &channel_info) { + if (current_sys_time_ms() - channel_info.resend_notify_timer > + StreamingConfig::RESEND_NOTIFY_MAX_INTERVAL) { + STREAMING_LOG(INFO) << "[Reader] Queue " << channel_info.channel_id + << " get item timeout, resend notify " + << channel_info.current_message_id; + reader_->NotifyConsumedItem(channel_info, channel_info.current_message_id); + channel_info.resend_notify_timer = current_sys_time_ms(); + } + return StreamingStatus::Invalid; +} + +ExactlyOnceHelper::ExactlyOnceHelper(const StreamingConfig &config, + StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader) + 
: ReliabilityHelper(config, barrier_helper, writer, reader) {} + +bool ExactlyOnceHelper::FilterMessage(ProducerChannelInfo &channel_info, + const uint8_t *data, + StreamingMessageType message_type, + uint64_t *write_message_id) { + bool is_filtered = ReliabilityHelper::FilterMessage(channel_info, data, message_type, + write_message_id); + if (is_filtered && StreamingMessageType::Barrier == message_type && + StreamingRole::SOURCE == config_.GetStreamingRole()) { + *write_message_id = channel_info.message_last_commit_id; + // Do not skip source barrier when it's reconstructing from downstream. + is_filtered = false; + STREAMING_LOG(INFO) << "append barrier to buffer ring " << *write_message_id + << ", last commit id " << channel_info.message_last_commit_id; + } + return is_filtered; +} +} // namespace streaming +} // namespace ray diff --git a/streaming/src/reliability_helper.h b/streaming/src/reliability_helper.h new file mode 100644 index 00000000..56089a08 --- /dev/null +++ b/streaming/src/reliability_helper.h @@ -0,0 +1,66 @@ +#pragma once +#include "channel/channel.h" +#include "data_reader.h" +#include "data_writer.h" +#include "reliability/barrier_helper.h" +#include "util/config.h" + +namespace ray { +namespace streaming { + +class ReliabilityHelper; +class DataWriter; +class DataReader; + +class ReliabilityHelperFactory { + public: + static std::shared_ptr CreateReliabilityHelper( + const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader); +}; + +class ReliabilityHelper { + public: + ReliabilityHelper(const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader); + virtual ~ReliabilityHelper() = default; + // Only exactly same need override this function. + virtual void Reload(); + // Store bundle meta or skip in replay mode. 
+ virtual bool StoreBundleMeta(ProducerChannelInfo &channel_info, + StreamingMessageBundlePtr &bundle_ptr, + bool is_replay = false); + virtual void CleanupCheckpoint(ProducerChannelInfo &channel_info, uint64_t barrier_id); + // Filter message by different failover strategies. + virtual bool FilterMessage(ProducerChannelInfo &channel_info, const uint8_t *data, + StreamingMessageType message_type, + uint64_t *write_message_id); + virtual StreamingStatus InitChannelMerger(uint32_t timeout); + virtual StreamingStatus HandleNoValidItem(ConsumerChannelInfo &channel_info); + + protected: + const StreamingConfig &config_; + StreamingBarrierHelper &barrier_helper_; + DataWriter *writer_; + DataReader *reader_; +}; + +class AtLeastOnceHelper : public ReliabilityHelper { + public: + AtLeastOnceHelper(const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader); + StreamingStatus InitChannelMerger(uint32_t timeout) override; + StreamingStatus HandleNoValidItem(ConsumerChannelInfo &channel_info) override; +}; + +class ExactlyOnceHelper : public ReliabilityHelper { + public: + ExactlyOnceHelper(const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader); + bool FilterMessage(ProducerChannelInfo &channel_info, const uint8_t *data, + StreamingMessageType message_type, + uint64_t *write_message_id) override; + virtual ~ExactlyOnceHelper() = default; +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/ring_buffer/ring_buffer.cc b/streaming/src/ring_buffer/ring_buffer.cc new file mode 100644 index 00000000..b9c1d213 --- /dev/null +++ b/streaming/src/ring_buffer/ring_buffer.cc @@ -0,0 +1,78 @@ +#include "ring_buffer/ring_buffer.h" + +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +StreamingRingBuffer::StreamingRingBuffer(size_t buf_size, + StreamingRingBufferType buffer_type) { + switch (buffer_type) { + case 
StreamingRingBufferType::SPSC: + message_buffer_ = + std::make_shared>(buf_size); + break; + case StreamingRingBufferType::SPSC_LOCK: + default: + message_buffer_ = + std::make_shared>(buf_size); + } +} + +bool StreamingRingBuffer::Push(const StreamingMessagePtr &msg) { + message_buffer_->Push(msg); + return true; +} + +StreamingMessagePtr &StreamingRingBuffer::Front() { + STREAMING_CHECK(!message_buffer_->Empty()); + return message_buffer_->Front(); +} + +void StreamingRingBuffer::Pop() { + STREAMING_CHECK(!message_buffer_->Empty()); + message_buffer_->Pop(); +} + +bool StreamingRingBuffer::IsFull() const { return message_buffer_->Full(); } + +bool StreamingRingBuffer::IsEmpty() const { return message_buffer_->Empty(); } + +size_t StreamingRingBuffer::Size() const { return message_buffer_->Size(); } + +size_t StreamingRingBuffer::Capacity() const { return message_buffer_->Capacity(); } + +size_t StreamingRingBuffer::GetTransientBufferSize() { + return transient_buffer_.GetTransientBufferSize(); +}; + +void StreamingRingBuffer::SetTransientBufferSize(uint32_t new_transient_buffer_size) { + return transient_buffer_.SetTransientBufferSize(new_transient_buffer_size); +} + +size_t StreamingRingBuffer::GetMaxTransientBufferSize() const { + return transient_buffer_.GetMaxTransientBufferSize(); +} + +const uint8_t *StreamingRingBuffer::GetTransientBuffer() const { + return transient_buffer_.GetTransientBuffer(); +} + +uint8_t *StreamingRingBuffer::GetTransientBufferMutable() const { + return transient_buffer_.GetTransientBufferMutable(); +} + +void StreamingRingBuffer::ReallocTransientBuffer(uint32_t size) { + transient_buffer_.ReallocTransientBuffer(size); +} + +bool StreamingRingBuffer::IsTransientAvaliable() { + return transient_buffer_.IsTransientAvaliable(); +} + +void StreamingRingBuffer::FreeTransientBuffer(bool is_force) { + transient_buffer_.FreeTransientBuffer(is_force); +} + +} // namespace streaming +} // namespace ray diff --git 
a/streaming/src/ring_buffer/ring_buffer.h b/streaming/src/ring_buffer/ring_buffer.h new file mode 100644 index 00000000..3ae1ff51 --- /dev/null +++ b/streaming/src/ring_buffer/ring_buffer.h @@ -0,0 +1,217 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "message/message.h" +#include "ray/common/status.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +/// Because the data cannot be successfully written to the channel every time, in +/// order not to serialize the message repeatedly, we designed a temporary buffer +/// area so that when the downstream is backpressured or the channel is blocked +/// due to memory limitations, it can be cached first and waited for the next use. +class StreamingTransientBuffer { + private: + std::shared_ptr transient_buffer_; + // BufferSize is length of last serialization data. + uint32_t transient_buffer_size_ = 0; + uint32_t max_transient_buffer_size_ = 0; + bool transient_flag_ = false; + + public: + inline size_t GetTransientBufferSize() const { return transient_buffer_size_; } + + inline void SetTransientBufferSize(uint32_t new_transient_buffer_size) { + transient_buffer_size_ = new_transient_buffer_size; + } + + inline size_t GetMaxTransientBufferSize() const { return max_transient_buffer_size_; } + + inline const uint8_t *GetTransientBuffer() const { return transient_buffer_.get(); } + + inline uint8_t *GetTransientBufferMutable() const { return transient_buffer_.get(); } + + /// To reuse transient buffer, we will realloc buffer memory if size of needed + /// message bundle raw data is greater-than original buffer size. 
+ /// \param size buffer size + /// + inline void ReallocTransientBuffer(uint32_t size) { + transient_buffer_size_ = size; + transient_flag_ = true; + if (max_transient_buffer_size_ > size) { + return; + } + max_transient_buffer_size_ = size; + transient_buffer_.reset(new uint8_t[size], std::default_delete()); + } + + inline bool IsTransientAvaliable() { return transient_flag_; } + + inline void FreeTransientBuffer(bool is_force = false) { + transient_buffer_size_ = 0; + transient_flag_ = false; + + // Transient buffer always holds max size buffer among all messages, which is + // wasteful. So expiration time is considerable idea to release large buffer if this + // transient buffer pointer hold it in long time. + + if (is_force) { + max_transient_buffer_size_ = 0; + transient_buffer_.reset(); + } + } + + virtual ~StreamingTransientBuffer() = default; +}; + +template +class AbstractRingBuffer { + public: + virtual void Push(const T &) = 0; + virtual void Pop() = 0; + virtual T &Front() = 0; + virtual bool Empty() const = 0; + virtual bool Full() const = 0; + virtual size_t Size() const = 0; + virtual size_t Capacity() const = 0; +}; + +template +class RingBufferImplThreadSafe : public AbstractRingBuffer { + public: + RingBufferImplThreadSafe(size_t size) : buffer_(size) {} + virtual ~RingBufferImplThreadSafe() = default; + void Push(const T &t) { + boost::unique_lock lock(ring_buffer_mutex_); + buffer_.push_back(t); + } + void Pop() { + boost::unique_lock lock(ring_buffer_mutex_); + buffer_.pop_front(); + } + T &Front() { + boost::shared_lock lock(ring_buffer_mutex_); + return buffer_.front(); + } + bool Empty() const { + boost::shared_lock lock(ring_buffer_mutex_); + return buffer_.empty(); + } + bool Full() const { + boost::shared_lock lock(ring_buffer_mutex_); + return buffer_.full(); + } + size_t Size() const { + boost::shared_lock lock(ring_buffer_mutex_); + return buffer_.size(); + } + size_t Capacity() const { return buffer_.capacity(); } + + private: + 
mutable boost::shared_mutex ring_buffer_mutex_; + boost::circular_buffer buffer_; +}; + +template +class RingBufferImplLockFree : public AbstractRingBuffer { + private: + std::vector buffer_; + std::atomic capacity_; + std::atomic read_index_; + std::atomic write_index_; + + public: + RingBufferImplLockFree(size_t size) + : buffer_(size, nullptr), capacity_(size), read_index_(0), write_index_(0) {} + virtual ~RingBufferImplLockFree() = default; + + void Push(const T &t) { + STREAMING_CHECK(!Full()); + buffer_[write_index_] = t; + write_index_ = IncreaseIndex(write_index_); + } + + void Pop() { + STREAMING_CHECK(!Empty()); + read_index_ = IncreaseIndex(read_index_); + } + + T &Front() { + STREAMING_CHECK(!Empty()); + return buffer_[read_index_]; + } + + bool Empty() const { return write_index_ == read_index_; } + + bool Full() const { return IncreaseIndex(write_index_) == read_index_; } + + size_t Size() const { return (write_index_ + capacity_ - read_index_) % capacity_; } + + size_t Capacity() const { return capacity_; } + + private: + size_t IncreaseIndex(size_t index) const { return (index + 1) % capacity_; } +}; + +enum class StreamingRingBufferType : uint8_t { SPSC_LOCK, SPSC }; + +/// StreamingRinggBuffer is factory to generate two different buffers. In data +/// writer, we use lock-free single producer single consumer (SPSC) ring buffer +/// to hold messages from user thread because SPSC has much better performance +/// than lock style. Since the SPSC_LOCK is useful to our event-driver model( +/// we will use that buffer to optimize our thread model in the future), so +/// it cann't be removed currently. 
+class StreamingRingBuffer { + private: + std::shared_ptr> message_buffer_; + + StreamingTransientBuffer transient_buffer_; + + public: + explicit StreamingRingBuffer(size_t buf_size, StreamingRingBufferType buffer_type = + StreamingRingBufferType::SPSC_LOCK); + + bool Push(const StreamingMessagePtr &msg); + + StreamingMessagePtr &Front(); + + void Pop(); + + bool IsFull() const; + + bool IsEmpty() const; + + size_t Size() const; + + size_t Capacity() const; + + size_t GetTransientBufferSize(); + + void SetTransientBufferSize(uint32_t new_transient_buffer_size); + + size_t GetMaxTransientBufferSize() const; + + const uint8_t *GetTransientBuffer() const; + + uint8_t *GetTransientBufferMutable() const; + + void ReallocTransientBuffer(uint32_t size); + + bool IsTransientAvaliable(); + + void FreeTransientBuffer(bool is_force = false); +}; + +typedef std::shared_ptr StreamingRingBufferPtr; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/runtime_context.cc b/streaming/src/runtime_context.cc new file mode 100644 index 00000000..782d6f3d --- /dev/null +++ b/streaming/src/runtime_context.cc @@ -0,0 +1,137 @@ +#include "runtime_context.h" + +#include "ray/common/id.h" +#include "ray/util/util.h" +#include "src/ray/protobuf/common.pb.h" +#include "util/streaming_logging.h" +#include "util/streaming_util.h" + +namespace ray { +namespace streaming { + +void RuntimeContext::SetConfig(const StreamingConfig &streaming_config) { + STREAMING_CHECK(runtime_status_ == RuntimeStatus::Init) + << "set config must be at beginning"; + config_ = streaming_config; +} + +void RuntimeContext::SetConfig(const uint8_t *data, uint32_t size) { + STREAMING_CHECK(runtime_status_ == RuntimeStatus::Init) + << "set config must be at beginning"; + if (!data) { + STREAMING_LOG(WARNING) << "buffer pointer is null, but len is => " << size; + return; + } + config_.FromProto(data, size); +} + +RuntimeContext::~RuntimeContext() {} + +RuntimeContext::RuntimeContext() + : 
enable_timer_service_(false), runtime_status_(RuntimeStatus::Init) {} + +void RuntimeContext::InitMetricsReporter() { + STREAMING_LOG(INFO) << "init metrics"; + if (!config_.GetMetricsEnable()) { + STREAMING_LOG(WARNING) << "metrics is disable"; + return; + } + perf_metrics_reporter_.reset(new StreamingReporter()); + + std::unordered_map default_tag_map = { + {"role", NodeType_Name(config_.GetNodeType())}, + {"op_name", config_.GetOpName()}, + {"worker_name", config_.GetWorkerName()}}; + metrics_config_.SetMetricsGlobalTags(default_tag_map); + + perf_metrics_reporter_->Start(metrics_config_); +} + +void RuntimeContext::ReportMetrics( + const std::string &metric_name, double value, + const std::unordered_map &tags) { + if (config_.GetMetricsEnable()) { + perf_metrics_reporter_->UpdateGauge(metric_name, tags, value); + } +} + +void RuntimeContext::RunTimer() { + AutoSpinLock lock(report_flag_); + if (runtime_status_ != RuntimeStatus::Running) { + STREAMING_LOG(WARNING) << "Run timer failed in state " + << static_cast(runtime_status_); + return; + } + STREAMING_LOG(INFO) << "Streaming metric timer called, interval=" + << metrics_config_.GetMetricsReportInterval(); + if (async_io_.stopped()) { + STREAMING_LOG(INFO) << "Async io stopped, return from timer reporting."; + return; + } + this->report_timer_handler_(); + boost::posix_time::seconds interval(metrics_config_.GetMetricsReportInterval()); + metrics_timer_->expires_from_now(interval); + metrics_timer_->async_wait([this](const boost::system::error_code &e) { + if (boost::asio::error::operation_aborted == e) { + return; + } + this->RunTimer(); + }); +} + +void RuntimeContext::EnableTimer(std::function report_timer_handler) { + if (!config_.GetMetricsEnable()) { + STREAMING_LOG(WARNING) << "Streaming metrics disabled."; + return; + } + if (enable_timer_service_) { + STREAMING_LOG(INFO) << "Timer service already enabled"; + return; + } + this->report_timer_handler_ = report_timer_handler; + STREAMING_LOG(INFO) << 
"Streaming metric timer enabled"; + // We new a thread for timer if timer is not alive currently. + if (!timer_thread_) { + async_io_.reset(); + boost::posix_time::seconds interval(metrics_config_.GetMetricsReportInterval()); + metrics_timer_.reset(new boost::asio::deadline_timer(async_io_, interval)); + metrics_timer_->async_wait( + [this](const boost::system::error_code & /*e*/) { this->RunTimer(); }); + timer_thread_ = std::make_shared([this]() { + STREAMING_LOG(INFO) << "Async io running."; + async_io_.run(); + }); + STREAMING_LOG(INFO) << "New thread " << timer_thread_->get_id(); + } + enable_timer_service_ = true; +} + +void RuntimeContext::ShutdownTimer() { + { + AutoSpinLock lock(report_flag_); + if (!config_.GetMetricsEnable()) { + STREAMING_LOG(WARNING) << "Streaming metrics disabled"; + return; + } + if (!enable_timer_service_) { + STREAMING_LOG(INFO) << "Timer service already disabled"; + return; + } + STREAMING_LOG(INFO) << "Timer server shutdown"; + enable_timer_service_ = false; + STREAMING_LOG(INFO) << "Cancel metrics timer."; + metrics_timer_->cancel(); + } + STREAMING_LOG(INFO) << "Wake up all reporting conditions."; + if (timer_thread_) { + STREAMING_LOG(INFO) << "Join and reset timer thread."; + if (timer_thread_->joinable()) { + timer_thread_->join(); + } + timer_thread_.reset(); + metrics_timer_.reset(); + } +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/runtime_context.h b/streaming/src/runtime_context.h new file mode 100644 index 00000000..5a49d82b --- /dev/null +++ b/streaming/src/runtime_context.h @@ -0,0 +1,89 @@ +// Copyright 2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "common/status.h" +#include "config/streaming_config.h" +#include "metrics/streaming_perf_metric.h" + +namespace ray { +namespace streaming { + +enum class RuntimeStatus : uint8_t { Init = 0, Running = 1, Interrupted = 2 }; + +#define RETURN_IF_NOT_OK(STATUS_EXP) \ + { \ + StreamingStatus state = STATUS_EXP; \ + if (StreamingStatus::OK != state) { \ + return state; \ + } \ + } + +class RuntimeContext { + public: + RuntimeContext(); + virtual ~RuntimeContext(); + inline const StreamingConfig &GetConfig() const { return config_; }; + void SetConfig(const StreamingConfig &config); + void SetConfig(const uint8_t *data, uint32_t buffer_len); + inline RuntimeStatus GetRuntimeStatus() { return runtime_status_; } + inline void SetRuntimeStatus(RuntimeStatus status) { runtime_status_ = status; } + inline void MarkMockTest() { is_mock_test_ = true; } + inline bool IsMockTest() { return is_mock_test_; } + + void InitMetricsReporter(); + + /// It's periodic reporter entry for all runtime modules. + /// \param metric_name, metric name + /// \param value, metric value + /// \param tags, metric tag map + void ReportMetrics(const std::string &metric_name, double value, + const std::unordered_map &tags = {}); + + /// Enable and register a specific reporting timer for updating all of metrics. + /// \param reporter_timer_handler + void EnableTimer(std::function report_timer_handler); + + /// Halt the timer invoking from now on. 
+ void ShutdownTimer(); + + private: + void RunTimer(); + + protected: + std::unique_ptr perf_metrics_reporter_; + std::function report_timer_handler_; + + boost::asio::io_service async_io_; + + private: + bool enable_timer_service_; + + std::unique_ptr metrics_timer_; + std::shared_ptr timer_thread_; + std::atomic_flag report_flag_ = ATOMIC_FLAG_INIT; + + StreamingConfig config_; + RuntimeStatus runtime_status_; + StreamingMetricsConfig metrics_config_; + bool is_mock_test_ = false; +}; + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/test/barrier_helper_tests.cc b/streaming/src/test/barrier_helper_tests.cc new file mode 100644 index 00000000..35fa5434 --- /dev/null +++ b/streaming/src/test/barrier_helper_tests.cc @@ -0,0 +1,157 @@ +#include "gtest/gtest.h" +#include "reliability/barrier_helper.h" + +using namespace ray::streaming; +using namespace ray; + +class StreamingBarrierHelperTest : public ::testing::Test { + public: + void SetUp() { barrier_helper_.reset(new StreamingBarrierHelper()); } + void TearDown() { barrier_helper_.release(); } + + protected: + std::unique_ptr barrier_helper_; + const ObjectID random_id = ray::ObjectID::FromRandom(); + const ObjectID another_random_id = ray::ObjectID::FromRandom(); +}; + +TEST_F(StreamingBarrierHelperTest, MsgIdByBarrierId) { + ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 0); + uint64_t msg_id = 0; + uint64_t init_msg_id = 10; + ASSERT_EQ(StreamingStatus::NoSuchItem, + barrier_helper_->GetMsgIdByBarrierId(random_id, 1, msg_id)); + + barrier_helper_->SetMsgIdByBarrierId(random_id, 1, init_msg_id); + + ASSERT_EQ(StreamingStatus::QueueIdNotFound, + barrier_helper_->GetMsgIdByBarrierId(another_random_id, 1, msg_id)); + + ASSERT_EQ(StreamingStatus::OK, + barrier_helper_->GetMsgIdByBarrierId(random_id, 1, msg_id)); + ASSERT_EQ(init_msg_id, msg_id); + + barrier_helper_->SetMsgIdByBarrierId(random_id, 2, init_msg_id + 1); + + ASSERT_EQ(StreamingStatus::OK, + 
barrier_helper_->GetMsgIdByBarrierId(random_id, 2, msg_id)); + ASSERT_EQ(init_msg_id + 1, msg_id); + + ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 2); + barrier_helper_->ReleaseBarrierMapById(1); + ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 1); + barrier_helper_->ReleaseAllBarrierMap(); + ASSERT_EQ(barrier_helper_->GetBarrierMapSize(), 0); +} + +TEST_F(StreamingBarrierHelperTest, BarrierIdByLastMessageId) { + uint64_t barrier_id = 0; + ASSERT_EQ(StreamingStatus::NoSuchItem, + barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id)); + + barrier_helper_->SetBarrierIdByLastMessageId(random_id, 1, 10); + + ASSERT_EQ( + StreamingStatus::QueueIdNotFound, + barrier_helper_->GetBarrierIdByLastMessageId(another_random_id, 1, barrier_id)); + + ASSERT_EQ(StreamingStatus::OK, + barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id)); + ASSERT_EQ(barrier_id, 10); + + barrier_helper_->SetBarrierIdByLastMessageId(random_id, 1, 11); + ASSERT_EQ(StreamingStatus::OK, + barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true)); + ASSERT_EQ(barrier_id, 10); + ASSERT_EQ(StreamingStatus::OK, + barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true)); + ASSERT_EQ(barrier_id, 11); + ASSERT_EQ(StreamingStatus::NoSuchItem, + barrier_helper_->GetBarrierIdByLastMessageId(random_id, 1, barrier_id, true)); +} + +TEST_F(StreamingBarrierHelperTest, CheckpointId) { + uint64_t checkpoint_id = static_cast(-1); + barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id); + ASSERT_EQ(checkpoint_id, 0); + barrier_helper_->SetCurrentMaxCheckpointIdInQueue(random_id, 2); + barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id); + ASSERT_EQ(checkpoint_id, 2); + barrier_helper_->SetCurrentMaxCheckpointIdInQueue(random_id, 3); + barrier_helper_->GetCurrentMaxCheckpointIdInQueue(random_id, checkpoint_id); + ASSERT_EQ(checkpoint_id, 3); +} + +TEST(BarrierHelper, barrier_map_get_set) { + 
StreamingBarrierHelper barrier_helper; + ray::ObjectID channel_id = ray::ObjectID::FromRandom(); + uint64_t msg_id; + auto status = barrier_helper.GetMsgIdByBarrierId(channel_id, 0, msg_id); + EXPECT_TRUE(status == StreamingStatus::NoSuchItem); + EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 0); + + msg_id = 1; + barrier_helper.SetMsgIdByBarrierId(channel_id, 0, msg_id); + + uint64_t fetched_msg_id; + status = barrier_helper.GetMsgIdByBarrierId(channel_id, 0, fetched_msg_id); + EXPECT_TRUE(status == StreamingStatus::OK); + EXPECT_TRUE(fetched_msg_id == msg_id); + EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 1); + + uint64_t fetched_no_barrier_id; + status = barrier_helper.GetMsgIdByBarrierId(channel_id, 1, fetched_no_barrier_id); + EXPECT_TRUE(status == StreamingStatus::NoSuchItem); + + ray::ObjectID other_channel_id = ray::ObjectID::FromRandom(); + status = barrier_helper.GetMsgIdByBarrierId(other_channel_id, 0, fetched_msg_id); + EXPECT_TRUE(status == StreamingStatus::QueueIdNotFound); + + EXPECT_TRUE(barrier_helper.Contains(0)); + EXPECT_TRUE(!barrier_helper.Contains(1)); + + msg_id = 10; + barrier_helper.SetMsgIdByBarrierId(channel_id, 1, msg_id); + EXPECT_TRUE(barrier_helper.Contains(1)); + EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 2); + + barrier_helper.ReleaseBarrierMapById(0); + EXPECT_TRUE(!barrier_helper.Contains(0)); + EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 1); + + msg_id = 20; + barrier_helper.SetMsgIdByBarrierId(channel_id, 2, msg_id); + std::vector barrier_id_vec; + barrier_helper.GetAllBarrier(barrier_id_vec); + EXPECT_TRUE(barrier_id_vec.size() == 2); + barrier_helper.ReleaseAllBarrierMap(); + EXPECT_TRUE(barrier_helper.GetBarrierMapSize() == 0); +} + +TEST(BarrierHelper, barrier_checkpoint_mapping) { + StreamingBarrierHelper barrier_helper; + ray::ObjectID channel_id = ray::ObjectID::FromRandom(); + uint64_t msg_id = 1; + uint64_t barrier_id = 0; + barrier_helper.SetMsgIdByBarrierId(channel_id, barrier_id, msg_id); + 
uint64_t checkpoint_id = 100; + barrier_helper.MapBarrierToCheckpoint(barrier_id, checkpoint_id); + uint64_t fetched_checkpoint_id; + barrier_helper.GetCheckpointIdByBarrierId(barrier_id, fetched_checkpoint_id); + EXPECT_TRUE(fetched_checkpoint_id == checkpoint_id); + + barrier_id = 2; + barrier_helper.MapBarrierToCheckpoint(barrier_id, checkpoint_id); + barrier_helper.GetCheckpointIdByBarrierId(barrier_id, fetched_checkpoint_id); + EXPECT_TRUE(fetched_checkpoint_id == checkpoint_id); + barrier_helper.ReleaseBarrierMapCheckpointByBarrierId(barrier_id); + + auto status1 = barrier_helper.GetCheckpointIdByBarrierId(0, fetched_checkpoint_id); + auto status2 = barrier_helper.GetCheckpointIdByBarrierId(2, fetched_checkpoint_id); + EXPECT_TRUE(status1 == status2 && status1 == StreamingStatus::NoSuchItem); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/data_writer_tests.cc b/streaming/src/test/data_writer_tests.cc new file mode 100644 index 00000000..61594e98 --- /dev/null +++ b/streaming/src/test/data_writer_tests.cc @@ -0,0 +1,138 @@ +#include "data_writer.h" +#include "gtest/gtest.h" + +namespace ray { +namespace streaming { +void GenRandomChannelIdVector(std::vector &input_ids, int n) { + for (int i = 0; i < n; ++i) { + input_ids.push_back(ObjectID::FromRandom()); + } +} + +class MockWriter : public DataWriter { + public: + friend class MockWriterTest; + MockWriter(std::shared_ptr runtime_context) + : DataWriter(runtime_context) {} + void Init(const std::vector &input_channel_vec) { + output_queue_ids_ = input_channel_vec; + for (size_t i = 0; i < input_channel_vec.size(); ++i) { + const ChannelCreationParameter param; + InitChannel(input_channel_vec[i], param, 0, 0xfff); + } + reliability_helper_ = ReliabilityHelperFactory::CreateReliabilityHelper( + runtime_context_->GetConfig(), barrier_helper_, this, nullptr); + event_service_ = std::make_shared(); + 
runtime_context_->SetRuntimeStatus(RuntimeStatus::Running); + event_service_->Run(); + } + + void Destroy() { + event_service_->Stop(); + event_service_.reset(); + } + + bool IsMessageAvailableInBuffer(const ObjectID &id) { + return DataWriter::IsMessageAvailableInBuffer(channel_info_map_[id]); + } + + std::unordered_map &GetChannelInfoMap() { + return channel_info_map_; + }; + + bool CollectFromRingBuffer(const ObjectID &id, uint64_t &buffer_remain) { + return DataWriter::CollectFromRingBuffer(channel_info_map_[id], buffer_remain); + } + + StreamingStatus WriteBufferToChannel(const ObjectID &id, uint64_t &buffer_remain) { + return DataWriter::WriteBufferToChannel(channel_info_map_[id], buffer_remain); + } + + void BroadcastBarrier(uint64_t barrier_id) { + static const uint8_t barrier_data[] = {1, 2, 3, 4}; + DataWriter::BroadcastBarrier(barrier_id, barrier_data, 4); + } + + uint64_t WriteMessageToBufferRing(const ObjectID &channel_id, uint8_t *data, + uint32_t data_size) { + return DataWriter::WriteMessageToBufferRing(channel_id, data, data_size); + } +}; + +class MockWriterTest : public ::testing::Test { + protected: + virtual void SetUp() override { + runtime_context.reset(new RuntimeContext()); + runtime_context->SetConfig(config); + runtime_context->MarkMockTest(); + mock_writer.reset(new MockWriter(runtime_context)); + } + virtual void TearDown() override { mock_writer->Destroy(); } + + protected: + std::shared_ptr runtime_context; + StreamingConfig config; + std::shared_ptr mock_writer; + std::vector input_ids; +}; + +TEST_F(MockWriterTest, test_message_avaliablie_in_buffer) { + int channel_num = 5; + GenRandomChannelIdVector(input_ids, channel_num); + mock_writer->Init(input_ids); + for (const auto &id : input_ids) { + EXPECT_TRUE(!mock_writer->IsMessageAvailableInBuffer(id)); + } + mock_writer->BroadcastBarrier(0); + for (const auto &id : input_ids) { + EXPECT_TRUE(mock_writer->IsMessageAvailableInBuffer(id)); + } +} + +uint8_t data[] = {0x01, 0x02, 0x0f, 
0xe, 0x00}; +uint32_t data_size = 5; + +TEST_F(MockWriterTest, test_write_message_to_buffer_ring) { + int channel_num = 2; + GenRandomChannelIdVector(input_ids, channel_num); + mock_writer->Init(input_ids); + for (const auto &id : input_ids) { + EXPECT_TRUE(!mock_writer->IsMessageAvailableInBuffer(id)); + } + mock_writer->WriteMessageToBufferRing(input_ids[0], data, data_size); + EXPECT_TRUE(mock_writer->IsMessageAvailableInBuffer(input_ids[0])); + EXPECT_TRUE(!mock_writer->IsMessageAvailableInBuffer(input_ids[1])); +} + +TEST_F(MockWriterTest, test_collecting_buffer) { + int channel_num = 1; + GenRandomChannelIdVector(input_ids, channel_num); + mock_writer->Init(input_ids); + mock_writer->WriteMessageToBufferRing(input_ids[0], data, data_size); + uint64_t buffer_remain; + mock_writer->CollectFromRingBuffer(input_ids[0], buffer_remain); + EXPECT_TRUE(buffer_remain == 0); + EXPECT_TRUE(mock_writer->IsMessageAvailableInBuffer(input_ids[0])); + EXPECT_TRUE(mock_writer->GetChannelInfoMap()[input_ids[0]] + .writer_ring_buffer->IsTransientAvaliable()); +} + +TEST_F(MockWriterTest, test_write_to_transfer) { + int channel_num = 1; + GenRandomChannelIdVector(input_ids, channel_num); + mock_writer->Init(input_ids); + mock_writer->WriteMessageToBufferRing(input_ids[0], data, data_size); + uint64_t buffer_remain; + EXPECT_EQ(mock_writer->WriteBufferToChannel(input_ids[0], buffer_remain), + StreamingStatus::OK); + EXPECT_TRUE(buffer_remain == 0); + EXPECT_TRUE(!mock_writer->IsMessageAvailableInBuffer(input_ids[0])); +} + +} // namespace streaming +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/event_service_tests.cc b/streaming/src/test/event_service_tests.cc new file mode 100644 index 00000000..5775cb24 --- /dev/null +++ b/streaming/src/test/event_service_tests.cc @@ -0,0 +1,93 @@ +#include + +#include "event_service.h" +#include "gtest/gtest.h" + +using namespace 
ray::streaming; + +/// Mock function for send empty message. +bool SendEmptyToChannel(ProducerChannelInfo *info) { return true; } + +/// Mock function for write all messages to channel. +bool WriteAllToChannel(ProducerChannelInfo *info) { return true; } + +TEST(EventServiceTest, Test1) { + std::shared_ptr server = std::make_shared(); + + ProducerChannelInfo mock_channel_info; + server->Register(EventType::EmptyEvent, SendEmptyToChannel); + server->Register(EventType::UserEvent, WriteAllToChannel); + server->Register(EventType::FlowEvent, WriteAllToChannel); + + bool stop = false; + std::thread thread_empty([server, &mock_channel_info, &stop] { + std::chrono::milliseconds MockTimer(20); + while (!stop) { + Event event(&mock_channel_info, EventType::EmptyEvent, true); + server->Push(event); + std::this_thread::sleep_for(MockTimer); + } + }); + + std::thread thread_flow([server, &mock_channel_info, &stop] { + std::chrono::milliseconds MockTimer(2); + while (!stop) { + Event event(&mock_channel_info, EventType::FlowEvent, true); + server->Push(event); + std::this_thread::sleep_for(MockTimer); + } + }); + + std::thread thread_user([server, &mock_channel_info, &stop] { + std::chrono::milliseconds MockTimer(2); + while (!stop) { + Event event(&mock_channel_info, EventType::UserEvent, true); + server->Push(event); + std::this_thread::sleep_for(MockTimer); + } + }); + + server->Run(); + + std::this_thread::sleep_for(std::chrono::milliseconds(5 * 1000)); + + STREAMING_LOG(INFO) << "5 seconds passed."; + STREAMING_LOG(INFO) << "EventNums: " << server->EventNums(); + stop = true; + STREAMING_LOG(INFO) << "Stop"; + server->Stop(); + thread_empty.join(); + thread_flow.join(); + thread_user.join(); +} + +TEST(EventServiceTest, remove_delete_channel_event) { + std::shared_ptr server = std::make_shared(); + + std::vector channel_vec; + std::vector mock_channel_info_vec; + channel_vec.push_back(ObjectID::FromRandom()); + ProducerChannelInfo mock_channel_info1; + 
mock_channel_info1.channel_id = channel_vec.back(); + mock_channel_info_vec.push_back(mock_channel_info1); + ProducerChannelInfo mock_channel_info2; + channel_vec.push_back(ObjectID::FromRandom()); + mock_channel_info2.channel_id = channel_vec.back(); + mock_channel_info_vec.push_back(mock_channel_info2); + + for (auto &id : mock_channel_info_vec) { + Event empty_event(&id, EventType::EmptyEvent, true); + Event user_event(&id, EventType::UserEvent, true); + Event flow_event(&id, EventType::FlowEvent, true); + server->Push(empty_event); + server->Push(user_event); + server->Push(flow_event); + } + std::vector removed_vec(channel_vec.begin(), channel_vec.begin() + 1); + server->RemoveDestroyedChannelEvent(removed_vec); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/message_serialization_tests.cc b/streaming/src/test/message_serialization_tests.cc new file mode 100644 index 00000000..f38fc223 --- /dev/null +++ b/streaming/src/test/message_serialization_tests.cc @@ -0,0 +1,175 @@ +#include +#include + +#include "gtest/gtest.h" +#include "message/message.h" +#include "message/message_bundle.h" + +using namespace ray; +using namespace ray::streaming; + +TEST(StreamingSerializationTest, streaming_message_serialization_test) { + uint8_t data[] = {9, 1, 3}; + StreamingMessagePtr message = + std::make_shared(data, 3, 7, StreamingMessageType::Message); + uint32_t message_length = message->ClassBytesSize(); + uint8_t *bytes = new uint8_t[message_length]; + message->ToBytes(bytes); + StreamingMessagePtr new_message = StreamingMessage::FromBytes(bytes); + EXPECT_EQ(std::memcmp(new_message->Payload(), data, 3), 0); + delete[] bytes; +} + +TEST(StreamingSerializationTest, streaming_message_empty_bundle_serialization_test) { + for (int i = 0; i < 10; ++i) { + StreamingMessageBundle bundle(i, i); + uint64_t bundle_size = bundle.ClassBytesSize(); + uint8_t *bundle_bytes = new 
uint8_t[bundle_size]; + bundle.ToBytes(bundle_bytes); + StreamingMessageBundlePtr bundle_ptr = + StreamingMessageBundle::FromBytes(bundle_bytes); + + EXPECT_EQ(bundle.ClassBytesSize(), bundle_ptr->ClassBytesSize()); + EXPECT_EQ(bundle.GetMessageListSize(), bundle_ptr->GetMessageListSize()); + EXPECT_EQ(bundle.GetBundleType(), bundle_ptr->GetBundleType()); + EXPECT_EQ(bundle.GetLastMessageId(), bundle_ptr->GetLastMessageId()); + std::list s_message_list; + bundle_ptr->GetMessageList(s_message_list); + std::list b_message_list; + bundle.GetMessageList(b_message_list); + EXPECT_EQ(b_message_list.size(), 0); + EXPECT_EQ(s_message_list.size(), 0); + + delete[] bundle_bytes; + } +} +TEST(StreamingSerializationTest, streaming_message_barrier_bundle_serialization_test) { + for (int i = 0; i < 10; ++i) { + uint8_t data[] = {1, 2, 3, 4}; + uint32_t data_size = 4; + uint32_t head_size = sizeof(uint64_t); + uint64_t checkpoint_id = 777; + std::shared_ptr ptr(new uint8_t[data_size + head_size], + std::default_delete()); + // move checkpint_id in head of barrier data + std::memcpy(ptr.get(), &checkpoint_id, head_size); + std::memcpy(ptr.get() + head_size, data, data_size); + StreamingMessagePtr message = std::make_shared( + data, head_size + data_size, i, StreamingMessageType::Barrier); + std::list message_list; + message_list.push_back(message); + // message list will be moved to bundle member + std::list message_list_cpy(message_list); + + StreamingMessageBundle bundle(message_list_cpy, i, i, + StreamingMessageBundleType::Barrier); + uint64_t bundle_size = bundle.ClassBytesSize(); + uint8_t *bundle_bytes = new uint8_t[bundle_size]; + bundle.ToBytes(bundle_bytes); + StreamingMessageBundlePtr bundle_ptr = + StreamingMessageBundle::FromBytes(bundle_bytes); + + EXPECT_TRUE(bundle.ClassBytesSize() == bundle_ptr->ClassBytesSize()); + EXPECT_TRUE(bundle.GetMessageListSize() == bundle_ptr->GetMessageListSize()); + EXPECT_TRUE(bundle.GetBundleType() == bundle_ptr->GetBundleType()); + 
EXPECT_TRUE(bundle.GetLastMessageId() == bundle_ptr->GetLastMessageId()); + std::list s_message_list; + bundle_ptr->GetMessageList(s_message_list); + EXPECT_TRUE(s_message_list.size() == message_list.size()); + auto m_item = message_list.back(); + auto s_item = s_message_list.back(); + EXPECT_TRUE(s_item->ClassBytesSize() == m_item->ClassBytesSize()); + EXPECT_TRUE(s_item->GetMessageType() == m_item->GetMessageType()); + EXPECT_TRUE(s_item->GetMessageId() == m_item->GetMessageId()); + EXPECT_TRUE(s_item->PayloadSize() == m_item->PayloadSize()); + EXPECT_TRUE( + std::memcmp(s_item->Payload(), m_item->Payload(), m_item->PayloadSize()) == 0); + EXPECT_TRUE(*(s_item.get()) == (*(m_item.get()))); + + delete[] bundle_bytes; + } +} + +TEST(StreamingSerializationTest, streaming_message_bundle_serialization_test) { + for (int k = 0; k <= 1000; k++) { + std::list message_list; + + for (int i = 0; i < 100; ++i) { + uint8_t *data = new uint8_t[i + 1]; + data[0] = i; + StreamingMessagePtr message = std::make_shared( + data, i + 1, i + 1, StreamingMessageType::Message); + message_list.push_back(message); + delete[] data; + } + StreamingMessageBundle messageBundle(message_list, 0, 1, + StreamingMessageBundleType::Bundle); + size_t message_length = messageBundle.ClassBytesSize(); + uint8_t *bytes = new uint8_t[message_length]; + messageBundle.ToBytes(bytes); + + StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(bytes); + EXPECT_EQ(bundle_ptr->ClassBytesSize(), message_length); + std::list s_message_list; + bundle_ptr->GetMessageList(s_message_list); + EXPECT_TRUE(bundle_ptr->operator==(messageBundle)); + StreamingMessageBundleMetaPtr bundle_meta_ptr = + StreamingMessageBundleMeta::FromBytes(bytes); + + EXPECT_EQ(bundle_meta_ptr->GetBundleType(), bundle_ptr->GetBundleType()); + EXPECT_EQ(bundle_meta_ptr->GetLastMessageId(), bundle_ptr->GetLastMessageId()); + EXPECT_EQ(bundle_meta_ptr->GetMessageBundleTs(), bundle_ptr->GetMessageBundleTs()); + 
EXPECT_EQ(bundle_meta_ptr->GetMessageListSize(), bundle_ptr->GetMessageListSize()); + delete[] bytes; + } +} + +TEST(StreamingSerializationTest, streaming_message_bundle_equal_test) { + std::list message_list; + std::list message_list_same; + std::list message_list_cpy; + for (int i = 0; i < 100; ++i) { + uint8_t *data = new uint8_t[i + 1]; + for (int j = 0; j < i + 1; ++j) { + data[j] = i; + } + StreamingMessagePtr message = std::make_shared( + data, i + 1, i + 1, StreamingMessageType::Message); + message_list.push_back(message); + message_list_cpy.push_front(message); + delete[] data; + } + for (int i = 0; i < 100; ++i) { + uint8_t *data = new uint8_t[i + 1]; + for (int j = 0; j < i + 1; ++j) { + data[j] = i; + } + StreamingMessagePtr message = std::make_shared( + data, i + 1, i + 1, StreamingMessageType::Message); + message_list_same.push_back(message); + delete[] data; + } + StreamingMessageBundle message_bundle(message_list, 0, 1, + StreamingMessageBundleType::Bundle); + StreamingMessageBundle message_bundle_same(message_list_same, 0, 1, + StreamingMessageBundleType::Bundle); + StreamingMessageBundle message_bundle_reverse(message_list_cpy, 0, 1, + StreamingMessageBundleType::Bundle); + EXPECT_TRUE(message_bundle_same == message_bundle); + EXPECT_FALSE(message_bundle_reverse == message_bundle); + size_t message_length = message_bundle.ClassBytesSize(); + uint8_t *bytes = new uint8_t[message_length]; + message_bundle.ToBytes(bytes); + + StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(bytes); + EXPECT_EQ(bundle_ptr->ClassBytesSize(), message_length); + std::list s_message_list; + bundle_ptr->GetMessageList(s_message_list); + EXPECT_TRUE(bundle_ptr->operator==(message_bundle)); + delete[] bytes; +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc new file mode 100644 index 00000000..3d5fda65 --- 
/dev/null +++ b/streaming/src/test/mock_actor.cc @@ -0,0 +1,658 @@ +#define BOOST_BIND_NO_PLACEHOLDERS +#include "common/status.h" +#include "data_reader.h" +#include "data_writer.h" +#include "gtest/gtest.h" +#include "message/message.h" +#include "message/message_bundle.h" +#include "queue/queue_client.h" +#include "ray/common/test_util.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" +#include "ring_buffer/ring_buffer.h" +using namespace std::placeholders; + +const uint32_t MESSAGE_BOUND_SIZE = 10000; +const uint32_t DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE = 1000; + +namespace ray { +namespace streaming { + +class StreamingQueueTestSuite { + public: + StreamingQueueTestSuite(ActorID &peer_actor_id, std::vector queue_ids, + std::vector rescale_queue_ids) + : peer_actor_id_(peer_actor_id), + queue_ids_(queue_ids), + rescale_queue_ids_(rescale_queue_ids) {} + + virtual void ExecuteTest(std::string test_name) { + auto it = test_func_map_.find(test_name); + STREAMING_CHECK(it != test_func_map_.end()); + current_test_ = test_name; + status_ = false; + auto func = it->second; + executor_thread_ = std::make_shared(func); + executor_thread_->detach(); + } + + virtual std::shared_ptr CheckCurTestStatus() { + TestCheckStatusRspMsg msg(current_test_, status_); + return msg.ToBytes(); + } + + virtual bool TestDone() { return status_; } + + virtual ~StreamingQueueTestSuite() {} + + protected: + std::unordered_map> test_func_map_; + std::string current_test_; + bool status_; + std::shared_ptr executor_thread_; + ActorID peer_actor_id_; + std::vector queue_ids_; + std::vector rescale_queue_ids_; +}; + +class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { + public: + StreamingQueueWriterTestSuite(ActorID &peer_actor_id, std::vector queue_ids, + std::vector rescale_queue_ids) + : StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) { + test_func_map_ = { + {"streaming_writer_exactly_once_test", + 
std::bind(&StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest, + this)}}; + } + + private: + void StreamingWriterExactlyOnceTest() { + StreamingConfig config; + StreamingWriterStrategyTest(config); + + STREAMING_LOG(INFO) + << "StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest"; + status_ = true; + } + + void StreamingWriterStrategyTest(StreamingConfig &config) { + for (auto &queue_id : queue_ids_) { + STREAMING_LOG(INFO) << "queue_id: " << queue_id; + } + ChannelCreationParameter param{ + peer_actor_id_, + std::make_shared( + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "reader_async_call_func", ""})), + std::make_shared( + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "reader_sync_call_func", ""}))}; + std::vector params(queue_ids_.size(), param); + STREAMING_LOG(INFO) << "writer actor_ids size: " << params.size() + << " actor_id: " << peer_actor_id_; + + std::shared_ptr runtime_context(new RuntimeContext()); + runtime_context->SetConfig(config); + + // Create writer. + std::shared_ptr streaming_writer_client(new DataWriter(runtime_context)); + uint64_t queue_size = 10 * 1000 * 1000; + std::vector channel_seq_id_vec(queue_ids_.size(), 0); + streaming_writer_client->Init(queue_ids_, params, channel_seq_id_vec, + std::vector(queue_ids_.size(), queue_size)); + STREAMING_LOG(INFO) << "streaming_writer_client Init done"; + + streaming_writer_client->Run(); + + // Write some data. 
+ std::thread test_loop_thread( + &StreamingQueueWriterTestSuite::TestWriteMessageToBufferRing, this, + streaming_writer_client, std::ref(queue_ids_)); + if (test_loop_thread.joinable()) { + test_loop_thread.join(); + } + } + + void TestWriteMessageToBufferRing(std::shared_ptr writer_client, + std::vector &q_list) { + uint32_t i = 1; + while (i <= MESSAGE_BOUND_SIZE) { + for (auto &q_id : q_list) { + uint64_t buffer_len = (i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE); + uint8_t *data = new uint8_t[buffer_len]; + for (uint32_t j = 0; j < buffer_len; ++j) { + data[j] = j % 128; + } + STREAMING_LOG(DEBUG) << "Write data to queue, count=" << i + << ", queue_id=" << q_id; + writer_client->WriteMessageToBufferRing(q_id, data, buffer_len, + StreamingMessageType::Message); + if (i % 10 == 0) { + writer_client->BroadcastBarrier(i / 10, nullptr, 0); + } + } + ++i; + } + STREAMING_LOG(INFO) << "Write data done."; + // Wait a while. + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + } +}; + +class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite { + public: + StreamingQueueReaderTestSuite(ActorID peer_actor_id, std::vector queue_ids, + std::vector rescale_queue_ids) + : StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) { + test_func_map_ = { + {"streaming_writer_exactly_once_test", + std::bind(&StreamingQueueReaderTestSuite::StreamingWriterExactlyOnceTest, + this)}}; + } + + private: + void ReaderLoopForward(std::shared_ptr reader_client, + std::shared_ptr writer_client, + std::vector &queue_id_vec) { + uint64_t recevied_message_cnt = 0; + std::unordered_map queue_last_cp_id; + + for (auto &q_id : queue_id_vec) { + queue_last_cp_id[q_id] = 0; + } + STREAMING_LOG(INFO) << "Start read message bundle, queue_id_size=" + << queue_id_vec.size(); + while (true) { + std::shared_ptr msg; + StreamingStatus st = reader_client->GetBundle(100, msg); + + if (st != StreamingStatus::OK || !msg->data) { + STREAMING_LOG(DEBUG) << "read bundle 
timeout, status = " << (int)st; + continue; + } + + STREAMING_CHECK(msg.get() && msg->meta.get()) + << "read null pointer message, queue id => " << msg->from.Hex(); + + if (msg->meta->GetBundleType() == StreamingMessageBundleType::Barrier) { + StreamingBarrierHeader barrier_header; + StreamingMessage::GetBarrierIdFromRawData(msg->data + kMessageHeaderSize, + &barrier_header); + STREAMING_LOG(DEBUG) << "barrier message recevied, time=" + << msg->meta->GetMessageBundleTs() + << ", barrier_id=" << barrier_header.barrier_id + << ", data=" << Util::Byte2hex(msg->data, msg->data_size); + std::unordered_map *offset_map; + reader_client->GetOffsetInfo(offset_map); + + for (auto &q_id : queue_id_vec) { + reader_client->NotifyConsumedItem((*offset_map)[q_id], + (*offset_map)[q_id].current_message_id); + } + // writer_client->ClearCheckpoint(msg->last_barrier_id); + + continue; + } else if (msg->meta->GetBundleType() == StreamingMessageBundleType::Empty) { + STREAMING_LOG(DEBUG) << "empty message recevied => " + << msg->meta->GetMessageBundleTs(); + continue; + } + + StreamingMessageBundlePtr bundlePtr; + bundlePtr = StreamingMessageBundle::FromBytes(msg->data); + std::list message_list; + bundlePtr->GetMessageList(message_list); + STREAMING_LOG(INFO) << "message size => " << message_list.size() + << " from queue id => " << msg->from.Hex() + << " last message id => " << msg->meta->GetLastMessageId(); + + recevied_message_cnt += message_list.size(); + for (auto &item : message_list) { + uint64_t i = item->GetMessageId(); + + uint32_t buff_len = i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE; + if (i > MESSAGE_BOUND_SIZE) break; + + EXPECT_EQ(buff_len, item->PayloadSize()); + uint8_t *compared_data = new uint8_t[buff_len]; + for (uint32_t j = 0; j < item->PayloadSize(); ++j) { + compared_data[j] = j % 128; + } + EXPECT_EQ(std::memcmp(compared_data, item->Payload(), item->PayloadSize()), 0); + delete[] compared_data; + } + STREAMING_LOG(DEBUG) << "Received message count => " << 
recevied_message_cnt; + if (recevied_message_cnt == queue_id_vec.size() * MESSAGE_BOUND_SIZE) { + STREAMING_LOG(INFO) << "recevied message count => " << recevied_message_cnt + << ", break"; + break; + } + } + } + + void StreamingReaderStrategyTest(StreamingConfig &config) { + ChannelCreationParameter param{ + peer_actor_id_, + std::make_shared( + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "writer_async_call_func", ""})), + std::make_shared( + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "writer_sync_call_func", ""}))}; + std::vector params(queue_ids_.size(), param); + STREAMING_LOG(INFO) << "reader actor_ids size: " << params.size() + << " actor_id: " << peer_actor_id_; + std::shared_ptr runtime_context(new RuntimeContext()); + runtime_context->SetConfig(config); + std::shared_ptr reader(new DataReader(runtime_context)); + + reader->Init(queue_ids_, params, -1); + ReaderLoopForward(reader, nullptr, queue_ids_); + + STREAMING_LOG(INFO) << "Reader exit"; + } + + void StreamingWriterExactlyOnceTest() { + STREAMING_LOG(INFO) + << "StreamingQueueReaderTestSuite::StreamingWriterExactlyOnceTest"; + StreamingConfig config; + + StreamingReaderStrategyTest(config); + status_ = true; + } +}; + +class StreamingQueueUpStreamTestSuite : public StreamingQueueTestSuite { + public: + StreamingQueueUpStreamTestSuite(ActorID &peer_actor_id, std::vector queue_ids, + std::vector rescale_queue_ids) + : StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) { + test_func_map_ = { + {"pull_peer_async_test", + std::bind(&StreamingQueueUpStreamTestSuite::PullPeerAsyncTest, this)}, + {"get_queue_test", + std::bind(&StreamingQueueUpStreamTestSuite::GetQueueTest, this)}}; + } + + void GetQueueTest() { + // Sleep 2s, queue shoulde not exist when reader pull. 
+ std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService(); + ObjectID &queue_id = queue_ids_[0]; + RayFunction async_call_func{ + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "reader_async_call_func", ""})}; + RayFunction sync_call_func{ + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "reader_sync_call_func", ""})}; + upstream_handler->SetPeerActorID(queue_id, peer_actor_id_, async_call_func, + sync_call_func); + upstream_handler->CreateUpstreamQueue(queue_id, peer_actor_id_, 10240); + STREAMING_LOG(INFO) << "IsQueueExist: " + << upstream_handler->UpstreamQueueExists(queue_id); + + // Sleep 2s, No valid data when reader pull + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + + std::this_thread::sleep_for(std::chrono::milliseconds(10 * 1000)); + STREAMING_LOG(INFO) << "StreamingQueueUpStreamTestSuite::GetQueueTest done"; + status_ = true; + } + + void PullPeerAsyncTest() { + // Sleep 2s, queue should not exist when reader pull. 
+ std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService(); + ObjectID &queue_id = queue_ids_[0]; + RayFunction async_call_func{ + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "reader_async_call_func", ""})}; + RayFunction sync_call_func{ + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "reader_sync_call_func", ""})}; + upstream_handler->SetPeerActorID(queue_id, peer_actor_id_, async_call_func, + sync_call_func); + std::shared_ptr queue = + upstream_handler->CreateUpstreamQueue(queue_id, peer_actor_id_, 10240); + STREAMING_LOG(INFO) << "IsQueueExist: " + << upstream_handler->UpstreamQueueExists(queue_id); + + // Sleep 2s, No valid data when reader pull + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + // message id starts from 1 + for (int msg_id = 1; msg_id <= 80; msg_id++) { + uint8_t data[100]; + memset(data, msg_id, 100); + STREAMING_LOG(INFO) << "Writer User Push item msg_id: " << msg_id; + ASSERT_TRUE( + queue->Push(data, 100, current_sys_time_ms(), msg_id, msg_id, true).ok()); + queue->Send(); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + STREAMING_LOG(INFO) << "StreamingQueueUpStreamTestSuite::PullPeerAsyncTest done"; + status_ = true; + } +}; + +class StreamingQueueDownStreamTestSuite : public StreamingQueueTestSuite { + public: + StreamingQueueDownStreamTestSuite(ActorID peer_actor_id, + std::vector queue_ids, + std::vector rescale_queue_ids) + : StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) { + test_func_map_ = { + {"pull_peer_async_test", + std::bind(&StreamingQueueDownStreamTestSuite::PullPeerAsyncTest, this)}, + {"get_queue_test", + std::bind(&StreamingQueueDownStreamTestSuite::GetQueueTest, this)}}; + }; + + void GetQueueTest() { + auto downstream_handler = 
ray::streaming::DownstreamQueueMessageHandler::GetService(); + ObjectID &queue_id = queue_ids_[0]; + RayFunction async_call_func{ + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "writer_async_call_func", ""})}; + RayFunction sync_call_func{ + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "writer_sync_call_func", ""})}; + downstream_handler->SetPeerActorID(queue_id, peer_actor_id_, async_call_func, + sync_call_func); + downstream_handler->CreateDownstreamQueue(queue_id, peer_actor_id_); + + bool is_upstream_first_pull_ = false; + downstream_handler->PullQueue(queue_id, 1, is_upstream_first_pull_, 10 * 1000); + ASSERT_TRUE(is_upstream_first_pull_); + downstream_handler->PullQueue(queue_id, 1, is_upstream_first_pull_, 10 * 1000); + ASSERT_FALSE(is_upstream_first_pull_); + STREAMING_LOG(INFO) << "StreamingQueueDownStreamTestSuite::GetQueueTest done"; + status_ = true; + } + + void PullPeerAsyncTest() { + auto downstream_handler = ray::streaming::DownstreamQueueMessageHandler::GetService(); + ObjectID &queue_id = queue_ids_[0]; + RayFunction async_call_func{ + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "writer_async_call_func", ""})}; + RayFunction sync_call_func{ + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::FromVector( + ray::Language::PYTHON, {"", "", "writer_sync_call_func", ""})}; + downstream_handler->SetPeerActorID(queue_id, peer_actor_id_, async_call_func, + sync_call_func); + std::shared_ptr queue = + downstream_handler->CreateDownstreamQueue(queue_id, peer_actor_id_); + + bool is_first_pull; + downstream_handler->PullQueue(queue_id, 1, is_first_pull, 10 * 1000); + uint64_t count = 0; + uint8_t msg_id = 1; + while (true) { + uint8_t *data = nullptr; + uint32_t data_size = 0; + uint64_t timeout_ms = 1000; + QueueItem item = queue->PopPendingBlockTimeout(timeout_ms * 1000); + if 
(item.SeqId() == QUEUE_INVALID_SEQ_ID) { + STREAMING_LOG(INFO) << "PopPendingBlockTimeout timeout."; + data = nullptr; + data_size = 0; + } else { + data = item.Buffer()->Data(); + data_size = item.Buffer()->Size(); + } + + STREAMING_LOG(INFO) << "[Reader] count: " << count; + if (data == nullptr) { + STREAMING_LOG(INFO) << "[Reader] data null"; + continue; + } + + for (uint32_t i = 0; i < data_size; i++) { + ASSERT_EQ(data[i], msg_id); + } + + count++; + if (count == 80) { + bool is_upstream_first_pull; + msg_id = 50; + downstream_handler->PullPeerAsync(queue_id, 50, is_upstream_first_pull, 1000); + continue; + } + + msg_id++; + STREAMING_LOG(INFO) << "[Reader] count: " << count; + if (count == 110) { + break; + } + } + + STREAMING_LOG(INFO) << "StreamingQueueDownStreamTestSuite::PullPeerAsyncTest done"; + status_ = true; + } +}; + +class TestSuiteFactory { + public: + static std::shared_ptr CreateTestSuite( + std::shared_ptr message) { + std::shared_ptr test_suite = nullptr; + std::string suite_name = message->TestSuiteName(); + queue::protobuf::StreamingQueueTestRole role = message->Role(); + const std::vector &queue_ids = message->QueueIds(); + const std::vector &rescale_queue_ids = message->RescaleQueueIds(); + ActorID peer_actor_id = message->PeerActorId(); + + if (role == queue::protobuf::StreamingQueueTestRole::WRITER) { + if (suite_name == "StreamingWriterTest") { + test_suite = std::make_shared( + peer_actor_id, queue_ids, rescale_queue_ids); + } else if (suite_name == "StreamingQueueTest") { + test_suite = std::make_shared( + peer_actor_id, queue_ids, rescale_queue_ids); + } else { + STREAMING_CHECK(false) << "unsurported suite_name: " << suite_name; + } + } else { + if (suite_name == "StreamingWriterTest") { + test_suite = std::make_shared( + peer_actor_id, queue_ids, rescale_queue_ids); + } else if (suite_name == "StreamingQueueTest") { + test_suite = std::make_shared( + peer_actor_id, queue_ids, rescale_queue_ids); + } else { + STREAMING_CHECK(false) 
<< "unsupported suite_name: " << suite_name; + } + } + + return test_suite; + } +}; + +class StreamingWorker { + public: + StreamingWorker(const std::string &store_socket, const std::string &raylet_socket, + int node_manager_port, const gcs::GcsClientOptions &gcs_options, + StartupToken startup_token) + : test_suite_(nullptr), peer_actor_handle_(nullptr) { + // You must keep it same with `src/ray/core_worker/core_worker.h:CoreWorkerOptions` + CoreWorkerOptions options; + options.worker_type = WorkerType::WORKER; + options.language = Language::PYTHON; + options.store_socket = store_socket; + options.raylet_socket = raylet_socket; + options.gcs_options = gcs_options; + options.enable_logging = true; + options.install_failure_signal_handler = true; + options.node_ip_address = "127.0.0.1"; + options.node_manager_port = node_manager_port; + options.raylet_ip_address = "127.0.0.1"; + options.task_execution_callback = std::bind(&StreamingWorker::ExecuteTask, this, _1, + _2, _3, _4, _5, _6, _7, _8, _9); + options.num_workers = 1; + options.metrics_agent_port = -1; + options.startup_token = startup_token; + CoreWorkerProcess::Initialize(options); + STREAMING_LOG(INFO) << "StreamingWorker constructor"; + } + + void RunTaskExecutionLoop() { + // Start executing tasks. + CoreWorkerProcess::RunTaskExecutionLoop(); + } + + private: + Status ExecuteTask(TaskType task_type, const std::string task_name, + const RayFunction &ray_function, + const std::unordered_map &required_resources, + const std::vector> &args, + const std::vector &arg_refs, + const std::vector &return_ids, + const std::string &debugger_breakpoint, + std::vector> *results) { + // Only one arg param used in streaming. 
+ STREAMING_CHECK(args.size() >= 1) << "args.size() = " << args.size(); + + ray::FunctionDescriptor function_descriptor = ray_function.GetFunctionDescriptor(); + RAY_CHECK(function_descriptor->Type() == + ray::FunctionDescriptorType::kPythonFunctionDescriptor); + auto typed_descriptor = function_descriptor->As(); + STREAMING_LOG(DEBUG) << "StreamingWorker::ExecuteTask " + << typed_descriptor->ToString(); + + std::string func_name = typed_descriptor->FunctionName(); + if (func_name == "init") { + std::shared_ptr local_buffer = + std::make_shared(args[0]->GetData()->Data(), + args[0]->GetData()->Size(), true); + HandleInitTask(local_buffer); + } else if (func_name == "execute_test") { + STREAMING_LOG(INFO) << "Test name: " << typed_descriptor->ClassName(); + test_suite_->ExecuteTest(typed_descriptor->ClassName()); + } else if (func_name == "check_current_test_status") { + results->push_back( + std::make_shared(test_suite_->CheckCurTestStatus(), nullptr, + std::vector())); + } else if (func_name == "reader_sync_call_func") { + if (test_suite_->TestDone()) { + STREAMING_LOG(WARNING) << "Test has done!!"; + return Status::OK(); + } + std::shared_ptr local_buffer = + std::make_shared(args[1]->GetData()->Data(), + args[1]->GetData()->Size(), true); + auto result_buffer = reader_client_->OnReaderMessageSync(local_buffer); + results->push_back(std::make_shared( + result_buffer, nullptr, std::vector())); + } else if (func_name == "reader_async_call_func") { + if (test_suite_->TestDone()) { + STREAMING_LOG(WARNING) << "Test has done!!"; + return Status::OK(); + } + std::shared_ptr local_buffer = + std::make_shared(args[1]->GetData()->Data(), + args[1]->GetData()->Size(), true); + reader_client_->OnReaderMessage(local_buffer); + } else if (func_name == "writer_sync_call_func") { + if (test_suite_->TestDone()) { + STREAMING_LOG(WARNING) << "Test has done!!"; + return Status::OK(); + } + std::shared_ptr local_buffer = + std::make_shared(args[1]->GetData()->Data(), + 
args[1]->GetData()->Size(), true); + auto result_buffer = writer_client_->OnWriterMessageSync(local_buffer); + results->push_back(std::make_shared( + result_buffer, nullptr, std::vector())); + } else if (func_name == "writer_async_call_func") { + if (test_suite_->TestDone()) { + STREAMING_LOG(WARNING) << "Test has done!!"; + return Status::OK(); + } + std::shared_ptr local_buffer = + std::make_shared(args[1]->GetData()->Data(), + args[1]->GetData()->Size(), true); + writer_client_->OnWriterMessage(local_buffer); + } else { + STREAMING_LOG(WARNING) << "Invalid function name " << func_name; + } + + return Status::OK(); + } + + private: + void HandleInitTask(std::shared_ptr buffer) { + reader_client_ = std::make_shared(); + writer_client_ = std::make_shared(); + uint8_t *bytes = buffer->Data(); + uint8_t *p_cur = bytes; + uint32_t *magic_num = (uint32_t *)p_cur; + STREAMING_CHECK(*magic_num == Message::MagicNum); + + p_cur += sizeof(Message::MagicNum); + queue::protobuf::StreamingQueueMessageType *type = + (queue::protobuf::StreamingQueueMessageType *)p_cur; + STREAMING_CHECK( + *type == + queue::protobuf::StreamingQueueMessageType::StreamingQueueTestInitMsgType); + std::shared_ptr message = TestInitMessage::FromBytes(bytes); + + std::string actor_handle_serialized = message->ActorHandleSerialized(); + CoreWorkerProcess::GetCoreWorker().DeserializeAndRegisterActorHandle( + actor_handle_serialized, ObjectID::Nil()); + std::shared_ptr actor_handle(new ActorHandle(actor_handle_serialized)); + STREAMING_CHECK(actor_handle != nullptr); + STREAMING_LOG(INFO) << "Actor id from handle: " << actor_handle->GetActorID(); + + STREAMING_LOG(INFO) << "HandleInitTask queues:"; + for (auto qid : message->QueueIds()) { + STREAMING_LOG(INFO) << "queue: " << qid; + } + for (auto qid : message->RescaleQueueIds()) { + STREAMING_LOG(INFO) << "rescale queue: " << qid; + } + + test_suite_ = TestSuiteFactory::CreateTestSuite(message); + STREAMING_CHECK(test_suite_ != nullptr); + } + + 
private: + std::shared_ptr reader_client_; + std::shared_ptr writer_client_; + std::shared_ptr test_thread_; + std::shared_ptr test_suite_; + std::shared_ptr peer_actor_handle_; +}; + +} // namespace streaming +} // namespace ray + +int main(int argc, char **argv) { + RAY_CHECK(argc >= 4); + auto store_socket = std::string(argv[1]); + auto raylet_socket = std::string(argv[2]); + auto node_manager_port = std::stoi(std::string(argv[3])); + // auto runtime_env_hash = std::string(argv[4]); // Unused in this test + auto startup_token_str = std::string(argv[4]); + auto start = startup_token_str.find(std::string("=")) + 1; + auto startup_token = std::stoi(startup_token_str.substr(start)); + + ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, ""); + ray::streaming::StreamingWorker worker(store_socket, raylet_socket, node_manager_port, + gcs_options, startup_token); + worker.RunTaskExecutionLoop(); + return 0; +} diff --git a/streaming/src/test/mock_transfer_tests.cc b/streaming/src/test/mock_transfer_tests.cc new file mode 100644 index 00000000..da0f232b --- /dev/null +++ b/streaming/src/test/mock_transfer_tests.cc @@ -0,0 +1,191 @@ +#include "data_reader.h" +#include "data_writer.h" +#include "gtest/gtest.h" + +using namespace ray; +using namespace ray::streaming; + +TEST(StreamingMockTransfer, mock_produce_consume) { + std::shared_ptr transfer_config; + ObjectID channel_id = ObjectID::FromRandom(); + ProducerChannelInfo producer_channel_info; + producer_channel_info.channel_id = channel_id; + producer_channel_info.current_message_id = 0; + MockProducer producer(transfer_config, producer_channel_info); + + ConsumerChannelInfo consumer_channel_info; + consumer_channel_info.channel_id = channel_id; + MockConsumer consumer(transfer_config, consumer_channel_info); + + producer.CreateTransferChannel(); + uint8_t data[3] = {1, 2, 3}; + producer.ProduceItemToChannel(data, 3); + uint8_t *data_consumed; + uint32_t data_size_consumed; + 
consumer.ConsumeItemFromChannel(data_consumed, data_size_consumed, -1); + EXPECT_EQ(data_size_consumed, 3); + EXPECT_EQ(std::memcmp(data_consumed, data, 3), 0); + consumer.NotifyChannelConsumed(1); + + auto status = consumer.ConsumeItemFromChannel(data_consumed, data_size_consumed, -1); + EXPECT_EQ(status, StreamingStatus::NoSuchItem); +} + +class StreamingTransferTest : public ::testing::Test { + public: + StreamingTransferTest() { + writer_runtime_context = std::make_shared(); + reader_runtime_context = std::make_shared(); + writer_runtime_context->MarkMockTest(); + reader_runtime_context->MarkMockTest(); + writer = std::make_shared(writer_runtime_context); + reader = std::make_shared(reader_runtime_context); + } + virtual ~StreamingTransferTest() = default; + void InitTransfer(int channel_num = 1) { + for (int i = 0; i < channel_num; ++i) { + queue_vec.push_back(ObjectID::FromRandom()); + } + std::vector channel_id_vec(queue_vec.size(), 0); + std::vector queue_size_vec(queue_vec.size(), 10000); + std::vector params(queue_vec.size()); + std::vector creation_status; + writer->Init(queue_vec, params, channel_id_vec, queue_size_vec); + reader->Init(queue_vec, params, channel_id_vec, creation_status, -1); + } + void DestroyTransfer() { + writer.reset(); + reader.reset(); + } + + protected: + std::shared_ptr writer; + std::shared_ptr reader; + std::vector queue_vec; + std::shared_ptr writer_runtime_context; + std::shared_ptr reader_runtime_context; +}; + +TEST_F(StreamingTransferTest, exchange_single_channel_test) { + InitTransfer(); + writer->Run(); + uint8_t data[4] = {1, 2, 3, 0xff}; + uint32_t data_size = 4; + writer->WriteMessageToBufferRing(queue_vec[0], data, data_size); + std::shared_ptr msg; + reader->GetBundle(5000, msg); + StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data); + auto &message_list = bundle_ptr->GetMessageList(); + auto &message = message_list.front(); + EXPECT_EQ(std::memcmp(message->Payload(), data, 
data_size), 0); +} + +TEST_F(StreamingTransferTest, exchange_multichannel_test) { + int channel_num = 4; + InitTransfer(4); + writer->Run(); + for (int i = 0; i < channel_num; ++i) { + uint8_t data[4] = {1, 2, 3, (uint8_t)i}; + uint32_t data_size = 4; + writer->WriteMessageToBufferRing(queue_vec[i], data, data_size); + std::shared_ptr msg; + reader->GetBundle(5000, msg); + EXPECT_EQ(msg->from, queue_vec[i]); + StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data); + auto &message_list = bundle_ptr->GetMessageList(); + auto &message = message_list.front(); + EXPECT_EQ(std::memcmp(message->Payload(), data, data_size), 0); + } +} + +TEST_F(StreamingTransferTest, exchange_consumed_test) { + InitTransfer(); + writer->Run(); + uint32_t data_size = 8196; + std::shared_ptr data(new uint8_t[data_size]); + auto func = [data, data_size](int index) { std::fill_n(data.get(), data_size, index); }; + + size_t num = 10000; + std::thread write_thread([this, data, data_size, &func, num]() { + for (size_t i = 0; i < num; ++i) { + func(i); + writer->WriteMessageToBufferRing(queue_vec[0], data.get(), data_size); + } + }); + + std::list read_message_list; + while (read_message_list.size() < num) { + std::shared_ptr msg; + reader->GetBundle(5000, msg); + StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data); + auto &message_list = bundle_ptr->GetMessageList(); + std::copy(message_list.begin(), message_list.end(), + std::back_inserter(read_message_list)); + } + int index = 0; + for (auto &message : read_message_list) { + func(index++); + EXPECT_EQ(std::memcmp(message->Payload(), data.get(), data_size), 0); + } + write_thread.join(); +} + +TEST_F(StreamingTransferTest, flow_control_test) { + InitTransfer(); + writer->Run(); + uint32_t data_size = 8196; + std::shared_ptr data(new uint8_t[data_size]); + auto func = [data, data_size](int index) { std::fill_n(data.get(), data_size, index); }; + + size_t num = 10000; + std::thread 
write_thread([this, data, data_size, &func, num]() { + for (size_t i = 0; i < num; ++i) { + func(i); + writer->WriteMessageToBufferRing(queue_vec[0], data.get(), data_size); + } + }); + std::unordered_map *writer_offset_info = nullptr; + std::unordered_map *reader_offset_info = nullptr; + writer->GetOffsetInfo(writer_offset_info); + reader->GetOffsetInfo(reader_offset_info); + uint32_t writer_step = writer_runtime_context->GetConfig().GetWriterConsumedStep(); + uint32_t reader_step = reader_runtime_context->GetConfig().GetReaderConsumedStep(); + uint64_t &writer_current_msg_id = + (*writer_offset_info)[queue_vec[0]].current_message_id; + uint64_t &writer_last_commit_id = + (*writer_offset_info)[queue_vec[0]].message_last_commit_id; + uint64_t &writer_target_msg_id = + (*writer_offset_info)[queue_vec[0]].queue_info.target_message_id; + uint64_t &reader_target_msg_id = + (*reader_offset_info)[queue_vec[0]].queue_info.target_message_id; + do { + std::this_thread::sleep_for( + std::chrono::milliseconds(StreamingConfig::TIME_WAIT_UINT)); + STREAMING_LOG(INFO) << "Writer currrent msg id " << writer_current_msg_id + << ", writer target_msg_id=" << writer_target_msg_id + << ", consumer step " << writer_step; + } while (writer_current_msg_id < writer_step); + + std::list read_message_list; + while (read_message_list.size() < num) { + std::shared_ptr msg; + reader->GetBundle(1000, msg); + StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data); + auto &message_list = bundle_ptr->GetMessageList(); + std::copy(message_list.begin(), message_list.end(), + std::back_inserter(read_message_list)); + ASSERT_GE(writer_step, writer_last_commit_id - msg->meta->GetLastMessageId()); + ASSERT_GE(msg->meta->GetLastMessageId() + reader_step, reader_target_msg_id); + } + int index = 0; + for (auto &message : read_message_list) { + func(index++); + EXPECT_EQ(std::memcmp(message->Payload(), data.get(), data_size), 0); + } + write_thread.join(); +} + +int main(int 
argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/queue_protobuf_tests.cc b/streaming/src/test/queue_protobuf_tests.cc new file mode 100644 index 00000000..8acf8eeb --- /dev/null +++ b/streaming/src/test/queue_protobuf_tests.cc @@ -0,0 +1,30 @@ +#include +#include +#include "gtest/gtest.h" + +#include "queue/message.h" +using namespace ray; +using namespace ray::streaming; + +TEST(ProtoBufTest, MessageCommonTest) { + JobID job_id = JobID::FromInt(0); + TaskID task_id = TaskID::ForDriverTask(job_id); + ray::ActorID actor_id = ray::ActorID::Of(job_id, task_id, 0); + ray::ActorID peer_actor_id = ray::ActorID::Of(job_id, task_id, 1); + ObjectID queue_id = ray::ObjectID::FromRandom(); + + uint8_t data[128]; + std::shared_ptr buffer = + std::make_shared(data, 128, true); + DataMessage msg(actor_id, peer_actor_id, queue_id, 100, 1000, 2000, buffer, true); + std::unique_ptr serilized_buffer = msg.ToBytes(); + std::shared_ptr msg2 = DataMessage::FromBytes(serilized_buffer->Data()); + EXPECT_EQ(msg.ActorId(), msg2->ActorId()); + EXPECT_EQ(msg.PeerActorId(), msg2->PeerActorId()); + EXPECT_EQ(msg.QueueId(), msg2->QueueId()); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h new file mode 100644 index 00000000..4aec671a --- /dev/null +++ b/streaming/src/test/queue_tests_base.h @@ -0,0 +1,288 @@ +#pragma once + +#include "hiredis/hiredis.h" +#include "ray/common/common_protocol.h" +#include "ray/common/test_util.h" +#include "ray/util/filesystem.h" + +namespace ray { +namespace streaming { + +using namespace ray::core; + +ray::ObjectID RandomObjectID() { return ObjectID::FromRandom(); } + +static void flushall_redis(void) { + redisContext *context = redisConnect("127.0.0.1", 6379); + freeReplyObject(redisCommand(context, "FLUSHALL")); + 
freeReplyObject(redisCommand(context, "SET NumRedisShards 1")); + freeReplyObject(redisCommand(context, "LPUSH RedisShards 127.0.0.1:6380")); + redisFree(context); +} +/// Base class for real-world tests with streaming queue +class StreamingQueueTestBase : public ::testing::TestWithParam { + public: + StreamingQueueTestBase(int num_nodes, int port) + : gcs_options_("127.0.0.1", 6379, ""), node_manager_port_(port) { + TestSetupUtil::StartUpRedisServers(std::vector{6379, 6380}); + + // flush redis first. + flushall_redis(); + + RAY_CHECK(num_nodes >= 0); + if (num_nodes > 0) { + raylet_socket_names_.resize(num_nodes); + raylet_store_socket_names_.resize(num_nodes); + } + + // start gcs server + gcs_server_socket_name_ = TestSetupUtil::StartGcsServer("127.0.0.1"); + + // start raylet on each node. Assign each node with different resources so that + // a task can be scheduled to the desired node. + for (int i = 0; i < num_nodes; i++) { + raylet_socket_names_[i] = + TestSetupUtil::StartRaylet("127.0.0.1", node_manager_port_ + i, "127.0.0.1", + "\"CPU,4.0,resource" + std::to_string(i) + ",10\"", + &raylet_store_socket_names_[i]); + } + } + + ~StreamingQueueTestBase() { + STREAMING_LOG(INFO) << "Stop raylet store and actors"; + for (const auto &raylet_socket_name : raylet_socket_names_) { + TestSetupUtil::StopRaylet(raylet_socket_name); + } + + TestSetupUtil::StopGcsServer(gcs_server_socket_name_); + TestSetupUtil::ShutDownRedisServers(); + } + + JobID NextJobId() const { + static uint32_t job_counter = 1; + return JobID::FromInt(job_counter++); + } + + void InitWorker(ActorID &self_actor_id, ActorID &peer_actor_id, + const queue::protobuf::StreamingQueueTestRole role, + const std::vector &queue_ids, + const std::vector &rescale_queue_ids, std::string suite_name, + std::string test_name, uint64_t param) { + auto &driver = CoreWorkerProcess::GetCoreWorker(); + std::string forked_serialized_str; + ObjectID actor_handle_id; + Status st = 
driver.SerializeActorHandle(peer_actor_id, &forked_serialized_str, + &actor_handle_id); + STREAMING_CHECK(st.ok()); + STREAMING_LOG(INFO) << "forked_serialized_str: " << forked_serialized_str; + TestInitMessage msg(role, self_actor_id, peer_actor_id, forked_serialized_str, + queue_ids, rescale_queue_ids, suite_name, test_name, param); + + std::vector> args; + args.emplace_back(new TaskArgByValue(std::make_shared( + msg.ToBytes(), nullptr, std::vector(), true))); + std::unordered_map resources; + TaskOptions options{"", 0, resources}; + RayFunction func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::BuildPython("", "", "init", "")}; + + RAY_UNUSED(driver.SubmitActorTask(self_actor_id, func, args, options)); + } + + void SubmitTestToActor(ActorID &actor_id, const std::string test) { + auto &driver = CoreWorkerProcess::GetCoreWorker(); + uint8_t data[8]; + auto buffer = std::make_shared(data, 8, true); + std::vector> args; + args.emplace_back(new TaskArgByValue(std::make_shared( + buffer, nullptr, std::vector(), true))); + std::unordered_map resources; + TaskOptions options("", 0, resources); + RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "", test, "execute_test", "")}; + + RAY_UNUSED(driver.SubmitActorTask(actor_id, func, args, options)); + } + + bool CheckCurTest(ActorID &actor_id, const std::string test_name) { + auto &driver = CoreWorkerProcess::GetCoreWorker(); + uint8_t data[8]; + auto buffer = std::make_shared(data, 8, true); + std::vector> args; + args.emplace_back(new TaskArgByValue(std::make_shared( + buffer, nullptr, std::vector(), true))); + std::unordered_map resources; + TaskOptions options{"", 1, resources}; + RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "", "", "check_current_test_status", "")}; + + auto return_refs = driver.SubmitActorTask(actor_id, func, args, options); + auto return_ids = ObjectRefsToIds(return_refs.value()); + + std::vector wait_results; + 
std::vector> results; + Status wait_st = driver.Wait(return_ids, 1, 5 * 1000, &wait_results, true); + if (!wait_st.ok()) { + STREAMING_LOG(ERROR) << "Wait fail."; + return false; + } + STREAMING_CHECK(wait_results.size() >= 1); + if (!wait_results[0]) { + STREAMING_LOG(WARNING) << "Wait direct call fail."; + return false; + } + + Status get_st = driver.Get(return_ids, -1, &results); + if (!get_st.ok()) { + STREAMING_LOG(ERROR) << "Get fail."; + return false; + } + STREAMING_CHECK(results.size() >= 1); + if (results[0]->IsException()) { + STREAMING_LOG(INFO) << "peer actor may has exceptions."; + return false; + } + STREAMING_CHECK(results[0]->HasData()); + STREAMING_LOG(DEBUG) << "SendForResult result[0] DataSize: " << results[0]->GetSize(); + + const std::shared_ptr result_buffer = results[0]->GetData(); + std::shared_ptr return_buffer = + std::make_shared(result_buffer->Data(), result_buffer->Size(), + true); + + uint8_t *bytes = result_buffer->Data(); + uint8_t *p_cur = bytes; + uint32_t *magic_num = (uint32_t *)p_cur; + STREAMING_CHECK(*magic_num == Message::MagicNum); + + p_cur += sizeof(Message::MagicNum); + queue::protobuf::StreamingQueueMessageType *type = + (queue::protobuf::StreamingQueueMessageType *)p_cur; + STREAMING_CHECK(*type == queue::protobuf::StreamingQueueMessageType:: + StreamingQueueTestCheckStatusRspMsgType); + std::shared_ptr message = + TestCheckStatusRspMsg::FromBytes(bytes); + STREAMING_CHECK(message->TestName() == test_name); + return message->Status(); + } + + ActorID CreateActorHelper(const std::unordered_map &resources, + bool is_direct_call, int64_t max_restarts) { + std::unique_ptr actor_handle; + + // Test creating actor. 
+ uint8_t array[] = {1, 2, 3}; + auto buffer = std::make_shared(array, sizeof(array)); + + RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "", "", "actor creation task", "")}; + std::vector> args; + args.emplace_back(new TaskArgByValue(std::make_shared( + buffer, nullptr, std::vector()))); + + std::string name = ""; + std::string ray_namespace = ""; + rpc::SchedulingStrategy scheduling_strategy; + scheduling_strategy.mutable_default_scheduling_strategy(); + ActorCreationOptions actor_options{ + max_restarts, + /*max_task_retries=*/0, + /*max_concurrency=*/1, resources, resources, {}, + /*is_detached=*/false, name, ray_namespace, /*is_asyncio=*/false, + scheduling_strategy}; + // Create an actor. + ActorID actor_id; + RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().CreateActor( + func, args, actor_options, /*extension_data*/ "", &actor_id)); + return actor_id; + } + + void SubmitTest(uint32_t queue_num, std::string suite_name, std::string test_name, + uint64_t timeout_ms) { + std::vector queue_id_vec; + std::vector rescale_queue_id_vec; + for (uint32_t i = 0; i < queue_num; ++i) { + ObjectID queue_id = ray::ObjectID::FromRandom(); + queue_id_vec.emplace_back(queue_id); + } + + // One scale id + ObjectID rescale_queue_id = ray::ObjectID::FromRandom(); + rescale_queue_id_vec.emplace_back(rescale_queue_id); + + std::vector channel_seq_id_vec(queue_num, 0); + + for (size_t i = 0; i < queue_id_vec.size(); ++i) { + STREAMING_LOG(INFO) << " qid hex => " << queue_id_vec[i].Hex(); + } + for (auto &qid : rescale_queue_id_vec) { + STREAMING_LOG(INFO) << " rescale qid hex => " << qid.Hex(); + } + STREAMING_LOG(INFO) << "Sub process: writer."; + + // You must keep it same with `src/ray/core_worker/core_worker.h:CoreWorkerOptions` + CoreWorkerOptions options; + options.worker_type = WorkerType::DRIVER; + options.language = Language::PYTHON; + options.store_socket = raylet_store_socket_names_[0]; + options.raylet_socket = 
raylet_socket_names_[0]; + options.job_id = NextJobId(); + options.gcs_options = gcs_options_; + options.enable_logging = true; + options.install_failure_signal_handler = true; + options.node_ip_address = "127.0.0.1"; + options.node_manager_port = node_manager_port_; + options.raylet_ip_address = "127.0.0.1"; + options.driver_name = "queue_tests"; + options.num_workers = 1; + options.metrics_agent_port = -1; + InitShutdownRAII core_worker_raii(CoreWorkerProcess::Initialize, + CoreWorkerProcess::Shutdown, options); + + // Create writer and reader actors + std::unordered_map resources; + auto actor_id_writer = CreateActorHelper(resources, true, 0); + auto actor_id_reader = CreateActorHelper(resources, true, 0); + + InitWorker(actor_id_writer, actor_id_reader, + queue::protobuf::StreamingQueueTestRole::WRITER, queue_id_vec, + rescale_queue_id_vec, suite_name, test_name, GetParam()); + InitWorker(actor_id_reader, actor_id_writer, + queue::protobuf::StreamingQueueTestRole::READER, queue_id_vec, + rescale_queue_id_vec, suite_name, test_name, GetParam()); + + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + + SubmitTestToActor(actor_id_writer, test_name); + SubmitTestToActor(actor_id_reader, test_name); + + uint64_t slept_time_ms = 0; + while (slept_time_ms < timeout_ms) { + std::this_thread::sleep_for(std::chrono::milliseconds(5 * 1000)); + STREAMING_LOG(INFO) << "Check test status."; + if (CheckCurTest(actor_id_writer, test_name) && + CheckCurTest(actor_id_reader, test_name)) { + STREAMING_LOG(INFO) << "Test Success, Exit."; + return; + } + slept_time_ms += 5 * 1000; + } + + EXPECT_TRUE(false); + STREAMING_LOG(INFO) << "Test Timeout, Exit."; + } + + void SetUp() {} + + void TearDown() {} + + protected: + std::vector raylet_socket_names_; + std::vector raylet_store_socket_names_; + gcs::GcsClientOptions gcs_options_; + int node_manager_port_; + std::string gcs_server_socket_name_; +}; + +} // namespace streaming +} // namespace ray diff --git 
a/streaming/src/test/ring_buffer_tests.cc b/streaming/src/test/ring_buffer_tests.cc
new file mode 100644
index 00000000..1a3b05c3
--- /dev/null
+++ b/streaming/src/test/ring_buffer_tests.cc
@@ -0,0 +1,92 @@
+#include <cstring>
+#include <iostream>
+#include <thread>
+
+#include "gtest/gtest.h"
+#include "message/message.h"
+#include "ray/util/logging.h"
+#include "ring_buffer/ring_buffer.h"
+
+using namespace ray;
+using namespace ray::streaming;
+
+size_t data_n = 1000000;
+TEST(StreamingRingBufferTest, streaming_message_ring_buffer_test) {
+  for (int k = 0; k < 10000; ++k) {
+    StreamingRingBuffer ring_buffer(3, StreamingRingBufferType::SPSC_LOCK);
+    for (int i = 0; i < 5; ++i) {
+      uint8_t data[] = {1, 1, 3};
+      data[0] = i;
+      StreamingMessagePtr message =
+          std::make_shared<StreamingMessage>(data, 3, i, StreamingMessageType::Message);
+      EXPECT_EQ(ring_buffer.Push(message), true);
+      size_t ith = i >= 3 ? 3 : (i + 1);
+      EXPECT_EQ(ring_buffer.Size(), ith);
+    }
+    int th = 2;
+
+    while (!ring_buffer.IsEmpty()) {
+      StreamingMessagePtr message_ptr = ring_buffer.Front();
+      ring_buffer.Pop();
+      EXPECT_EQ(message_ptr->PayloadSize(), 3);
+      EXPECT_EQ(*(message_ptr->Payload()), th++);
+    }
+  }
+}
+
+TEST(StreamingRingBufferTest, spsc_test) {
+  size_t m_num = 1000;
+  StreamingRingBuffer ring_buffer(m_num, StreamingRingBufferType::SPSC);
+  std::thread thread([&ring_buffer]() {
+    for (size_t j = 0; j < data_n; ++j) {
+      StreamingMessagePtr message = std::make_shared<StreamingMessage>(
+          reinterpret_cast<uint8_t *>(&j), static_cast<uint32_t>(sizeof(size_t)), j,
+          StreamingMessageType::Message);
+      while (ring_buffer.IsFull()) {
+      }
+      ring_buffer.Push(message);
+    }
+  });
+  size_t count = 0;
+  while (count < data_n) {
+    while (ring_buffer.IsEmpty()) {
+    }
+    auto &msg = ring_buffer.Front();
+    EXPECT_EQ(std::memcmp(msg->Payload(), &count, sizeof(size_t)), 0);
+    ring_buffer.Pop();
+    count++;
+  }
+  thread.join();
+  EXPECT_EQ(count, data_n);
+}
+
+TEST(StreamingRingBufferTest, mutex_test) {
+  size_t m_num = data_n;
+  StreamingRingBuffer ring_buffer(m_num, StreamingRingBufferType::SPSC_LOCK);
+  std::thread thread([&ring_buffer]() {
+    for (size_t j = 0; j < data_n; ++j) {
+      StreamingMessagePtr message = std::make_shared<StreamingMessage>(
+          reinterpret_cast<uint8_t *>(&j), static_cast<uint32_t>(sizeof(size_t)), j,
+          StreamingMessageType::Message);
+      while (ring_buffer.IsFull()) {
+      }
+      ring_buffer.Push(message);
+    }
+  });
+  size_t count = 0;
+  while (count < data_n) {
+    while (ring_buffer.IsEmpty()) {
+    }
+    auto msg = ring_buffer.Front();
+    EXPECT_EQ(std::memcmp(msg->Payload(), &count, sizeof(size_t)), 0);
+    ring_buffer.Pop();
+    count++;
+  }
+  thread.join();
+  EXPECT_EQ(count, data_n);
+}
+
+int main(int argc, char **argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/streaming/src/test/run_streaming_queue_test.sh b/streaming/src/test/run_streaming_queue_test.sh
new file mode 100755
index 00000000..e777ffd5
--- /dev/null
+++ b/streaming/src/test/run_streaming_queue_test.sh
@@ -0,0 +1,65 @@
+#!/usr/bin/env bash
+
+# Run all streaming c++ tests using streaming queue, instead of plasma queue
+# This needs to be run in the root directory.
+
+# Try to find an unused port for raylet to use.
+PORTS="2000 2001 2002 2003 2004 2005 2006 2007 2008 2009"
+RAYLET_PORT=0
+for port in $PORTS; do
+  if ! nc -z localhost "$port"; then
+    RAYLET_PORT=$port
+    break
+  fi
+done
+
+if [[ $RAYLET_PORT == 0 ]]; then
+  echo "WARNING: Could not find unused port for raylet to use. Exiting without running tests."
+  exit
+fi
+
+# Cause the script to exit if a single command fails.
+set -e
+set -x
+
+# Get the directory in which this script is executing.
+SCRIPT_DIR="$(dirname "$0")"
+
+# Get the directory in which this script is executing.
+SCRIPT_DIR="$(dirname "$0")"
+RAY_ROOT="$SCRIPT_DIR/../../.."
+# Makes $RAY_ROOT an absolute path.
+RAY_ROOT="$(cd "$RAY_ROOT" && pwd)"
+if [ -z "$RAY_ROOT" ] ; then
+  exit 1
+fi
+
+bazel build "//:core_worker_test" "//:mock_worker" "//:raylet" "//:gcs_server" "//:redis-server" "//:redis-cli"
+bazel build //streaming:streaming_test_worker
+bazel build //streaming:streaming_queue_tests
+
+# Ensure we're in the right directory.
+if [ ! -d "$RAY_ROOT/python" ]; then
+  echo "Unable to find root Ray directory. Has this script moved?"
+  exit 1
+fi
+
+REDIS_SERVER_EXEC="$RAY_ROOT/bazel-bin/external/com_github_antirez_redis/redis-server"
+REDIS_CLIENT_EXEC="$RAY_ROOT/bazel-bin/redis-cli"
+RAYLET_EXEC="$RAY_ROOT/bazel-bin/raylet"
+STREAMING_TEST_WORKER_EXEC="$RAY_ROOT/bazel-bin/streaming/streaming_test_worker"
+GCS_SERVER_EXEC="$RAY_ROOT/bazel-bin/gcs_server"
+
+# clear env
+set +e
+pgrep "plasma|DefaultDriver|DefaultWorker|AppStarter|redis|http_server|job_agent" | xargs kill -9 &> /dev/null
+set -e
+
+# Run tests.
+
+# to run specific test, add --gtest_filter, below is an example
+#$RAY_ROOT/bazel-bin/streaming/streaming_queue_tests $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $STREAMING_TEST_WORKER_EXEC $GCS_SERVER_EXEC $REDIS_SERVER_EXEC $REDIS_CLIENT_EXEC --gtest_filter=StreamingTest/StreamingWriterTest.streaming_writer_exactly_once_test/0
+
+# run all tests
+"$RAY_ROOT"/bazel-bin/streaming/streaming_queue_tests "$RAYLET_EXEC" "$RAYLET_PORT" "$STREAMING_TEST_WORKER_EXEC" "$GCS_SERVER_EXEC" "$REDIS_SERVER_EXEC" "$REDIS_CLIENT_EXEC"
+sleep 1s
diff --git a/streaming/src/test/streaming_perf_tests.cc b/streaming/src/test/streaming_perf_tests.cc
new file mode 100644
index 00000000..3c4752b0
--- /dev/null
+++ b/streaming/src/test/streaming_perf_tests.cc
@@ -0,0 +1,120 @@
+#include <chrono>
+#include <iostream>
+#include <thread>
+#include <vector>
+
+#include "opencensus/stats/internal/delta_producer.h"
+#include "opencensus/stats/internal/stats_exporter_impl.h"
+#include "ray/stats/stats.h"
+
+#include "config/streaming_config.h"
+#include "gtest/gtest.h"
+#include "metrics/streaming_perf_metric.h"
+
+using namespace ray::streaming;
+using namespace ray;
+
+class StreamingReporterCounterTest : public ::testing::Test {
+ public:
+  using UpdateFunc = std::function<void(size_t)>;
+
+  void SetUp() {
+    uint32_t kReportFlushInterval = 100;
+    absl::Duration report_interval = absl::Milliseconds(kReportFlushInterval);
+    absl::Duration harvest_interval = absl::Milliseconds(kReportFlushInterval / 2);
+    ray::stats::StatsConfig::instance().SetReportInterval(report_interval);
+    ray::stats::StatsConfig::instance().SetHarvestInterval(harvest_interval);
+    const stats::TagsType global_tags = {{stats::ResourceNameKey, "CPU"}};
+    std::shared_ptr<stats::ExporterClient> exporter(
+        new stats::StdoutExporterClient());
+    ray::stats::Init(global_tags, 10054, exporter);
+
+    setenv("STREAMING_METRICS_MODE", "DEV", 1);
+    setenv("ENABLE_RAY_STATS", "ON", 1);
+    setenv("STREAMING_ENABLE_METRICS", "ON", 1);
+    perf_counter_.reset(new StreamingReporter());
+
+    const std::unordered_map<std::string, std::string> default_tags = {
+        {"app", "s_test"}, {"cluster", "kmon-dev"}};
+    metrics_conf_.SetMetricsGlobalTags(default_tags);
+    perf_counter_->Start(metrics_conf_);
+  }
+
+  void TearDown() {
+    opencensus::stats::DeltaProducer::Get()->Flush();
+    opencensus::stats::StatsExporterImpl::Get()->Export();
+    perf_counter_->Shutdown();
+    ray::stats::Shutdown();
+  }
+
+  void RegisterAndRun(UpdateFunc update_handler) {
+    auto stat_time_handler = [this](size_t thread_index, UpdateFunc update_handler) {
+      auto start = std::chrono::high_resolution_clock::now();
+
+      for (size_t loop_index = 0; loop_index < loop_update_times_; ++loop_index) {
+        update_handler(loop_index);
+      }
+
+      auto end = std::chrono::high_resolution_clock::now();
+      std::chrono::duration<double, std::micro> elapsed = end - start;
+
+      std::string info = "Thread=" + std::to_string(thread_index) +
+                         ", times=" + std::to_string(loop_update_times_) +
+                         ", cost=" + std::to_string(elapsed.count()) + "us.";
+      std::cout << info << std::endl;
+    };
+
+    for (size_t thread_index = 0; thread_index < op_thread_count_; ++thread_index) {
+      thread_pool_.emplace_back(
+          std::bind(stat_time_handler, thread_index, update_handler));
+    }
+
+    for (auto &thread : thread_pool_) {
+      thread.join();
+    }
+  }
+
+ protected:
+  size_t op_thread_count_{4};
+  size_t loop_update_times_{10};
+  std::vector<std::thread> thread_pool_;
+
+  StreamingMetricsConfig metrics_conf_;
+  std::unique_ptr<StreamingReporter> perf_counter_;
+};
+
+TEST_F(StreamingReporterCounterTest, UpdateCounterWithOneKeyTest) {
+  RegisterAndRun([this](size_t loop_index) {
+    perf_counter_->UpdateCounter("domaina", "groupa", "a", loop_index);
+  });
+}
+
+TEST_F(StreamingReporterCounterTest, UpdateCounterTest) {
+  RegisterAndRun([this](size_t loop_index) {
+    auto loop_index_str = std::to_string(loop_index % 10);
+    perf_counter_->UpdateCounter("domaina" + loop_index_str, "groupa" + loop_index_str,
+                                 "a" + loop_index_str, loop_index);
+  });
+}
+
+TEST_F(StreamingReporterCounterTest, UpdateGaugeWithOneKeyTest) {
+  RegisterAndRun([this](size_t loop_index) {
+    std::unordered_map<std::string, std::string> tags;
+    tags["tag1"] = "tag1";
+    tags["tag2"] = std::to_string(loop_index);
+    perf_counter_->UpdateGauge("streaming.test.gauge", tags, loop_index);
+  });
+}
+
+TEST_F(StreamingReporterCounterTest, UpdateGaugeTest) {
+  RegisterAndRun([this](size_t loop_index) {
+    auto loop_index_str = std::to_string(loop_index % 10);
+    perf_counter_->UpdateGauge("domaina" + loop_index_str, "groupa" + loop_index_str,
+                               "a" + loop_index_str, loop_index);
+  });
+}
+
+int main(int argc, char **argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/streaming/src/test/streaming_queue_tests.cc b/streaming/src/test/streaming_queue_tests.cc
new file mode 100644
index 00000000..e30cc725
--- /dev/null
+++ b/streaming/src/test/streaming_queue_tests.cc
@@ -0,0 +1,80 @@
+#define BOOST_BIND_NO_PLACEHOLDERS
+#include "data_reader.h"
+#include "data_writer.h"
+#include "gtest/gtest.h"
+#include "message/message.h"
+#include "message/message_bundle.h"
+#include "queue/queue_client.h"
+#include
"ray/common/test_util.h" +#include "ray/core_worker/core_worker.h" +#include "ring_buffer/ring_buffer.h" +#include "test/queue_tests_base.h" + +using namespace std::placeholders; +namespace ray { +namespace streaming { + +static int node_manager_port; + +class StreamingQueueTest : public StreamingQueueTestBase { + public: + StreamingQueueTest() : StreamingQueueTestBase(1, node_manager_port) {} +}; + +class StreamingWriterTest : public StreamingQueueTestBase { + public: + StreamingWriterTest() : StreamingQueueTestBase(1, node_manager_port) {} +}; + +class StreamingExactlySameTest : public StreamingQueueTestBase { + public: + StreamingExactlySameTest() : StreamingQueueTestBase(1, node_manager_port) {} +}; + +TEST_P(StreamingQueueTest, PullPeerAsyncTest) { + STREAMING_LOG(INFO) << "StreamingQueueTest.pull_peer_async_test"; + + uint32_t queue_num = 1; + SubmitTest(queue_num, "StreamingQueueTest", "pull_peer_async_test", 60 * 1000); +} + +TEST_P(StreamingQueueTest, GetQueueTest) { + STREAMING_LOG(INFO) << "StreamingQueueTest.get_queue_test"; + + uint32_t queue_num = 1; + SubmitTest(queue_num, "StreamingQueueTest", "get_queue_test", 60 * 1000); +} + +TEST_P(StreamingWriterTest, streaming_writer_exactly_once_test) { + STREAMING_LOG(INFO) << "StreamingWriterTest.streaming_writer_exactly_once_test"; + + uint32_t queue_num = 1; + + STREAMING_LOG(INFO) << "Streaming Strategy => EXACTLY ONCE"; + SubmitTest(queue_num, "StreamingWriterTest", "streaming_writer_exactly_once_test", + 60 * 1000); +} + +TEST_P(StreamingExactlySameTest, Hold) {} + +INSTANTIATE_TEST_SUITE_P(StreamingTest, StreamingQueueTest, testing::Values(0)); + +INSTANTIATE_TEST_SUITE_P(StreamingTest, StreamingWriterTest, testing::Values(0)); + +INSTANTIATE_TEST_SUITE_P(StreamingTest, StreamingExactlySameTest, + testing::Values(0, 1, 5, 9)); + +} // namespace streaming +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + RAY_CHECK(argc == 7); + 
ray::TEST_RAYLET_EXEC_PATH = std::string(argv[1]); + ray::streaming::node_manager_port = std::stoi(std::string(argv[2])); + ray::TEST_MOCK_WORKER_EXEC_PATH = std::string(argv[3]); + ray::TEST_GCS_SERVER_EXEC_PATH = std::string(argv[4]); + ray::TEST_REDIS_SERVER_EXEC_PATH = std::string(argv[5]); + ray::TEST_REDIS_CLIENT_EXEC_PATH = std::string(argv[6]); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/streaming_util_tests.cc b/streaming/src/test/streaming_util_tests.cc new file mode 100644 index 00000000..27e2ce7d --- /dev/null +++ b/streaming/src/test/streaming_util_tests.cc @@ -0,0 +1,23 @@ +#include "gtest/gtest.h" +#include "util/streaming_util.h" + +using namespace ray; +using namespace ray::streaming; + +TEST(StreamingUtilTest, test_Byte2hex) { + const uint8_t data[2] = {0x11, 0x07}; + EXPECT_TRUE(Util::Byte2hex(data, 2) == "1107"); + EXPECT_TRUE(Util::Byte2hex(data, 2) != "1108"); +} + +TEST(StreamingUtilTest, test_Hex2str) { + const uint8_t data[2] = {0x11, 0x07}; + EXPECT_TRUE(std::memcmp(Util::Hexqid2str("1107").c_str(), data, 2) == 0); + const uint8_t data2[2] = {0x10, 0x0f}; + EXPECT_TRUE(std::memcmp(Util::Hexqid2str("100f").c_str(), data2, 2) == 0); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/util/config.cc b/streaming/src/util/config.cc new file mode 100644 index 00000000..b5c3e320 --- /dev/null +++ b/streaming/src/util/config.cc @@ -0,0 +1,20 @@ +#include "config.h" +namespace ray { +namespace streaming { + +boost::any &Config::Get(ConfigEnum key) const { + auto item = config_map_.find(key); + STREAMING_CHECK(item != config_map_.end()); + return item->second; +} + +boost::any Config::Get(ConfigEnum key, boost::any default_value) const { + auto item = config_map_.find(key); + if (item == config_map_.end()) { + return default_value; + } + return item->second; +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/util/config.h 
b/streaming/src/util/config.h
new file mode 100644
index 00000000..56c6af81
--- /dev/null
+++ b/streaming/src/util/config.h
@@ -0,0 +1,80 @@
+#pragma once
+#include <boost/any.hpp>
+#include <unordered_map>
+
+#include "streaming_logging.h"
+
+namespace ray {
+namespace streaming {
+enum class ConfigEnum : uint32_t {
+  QUEUE_ID_VECTOR = 0,
+  MIN = QUEUE_ID_VECTOR,
+  MAX = QUEUE_ID_VECTOR
+};
+}
+}  // namespace ray
+
+namespace std {
+template <>
+struct hash<::ray::streaming::ConfigEnum> {
+  size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const {
+    return static_cast<size_t>(config_enum_key);
+  }
+};
+
+template <>
+struct hash<const ::ray::streaming::ConfigEnum> {
+  size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const {
+    return static_cast<size_t>(config_enum_key);
+  }
+};
+}  // namespace std
+
+namespace ray {
+namespace streaming {
+
+class Config {
+ public:
+  template <typename ValueType>
+  inline void Set(ConfigEnum key, const ValueType &any) {
+    config_map_.emplace(key, any);
+  }
+
+  template <typename ValueType>
+  inline void Set(ConfigEnum key, ValueType &&any) {
+    config_map_.emplace(key, any);
+  }
+
+  template <typename ValueType>
+  inline boost::any &GetOrDefault(ConfigEnum key, ValueType &&any) {
+    auto item = config_map_.find(key);
+    if (item != config_map_.end()) {
+      return item->second;
+    }
+    Set(key, any);
+    return any;
+  }
+
+  boost::any &Get(ConfigEnum key) const;
+  boost::any Get(ConfigEnum key, boost::any default_value) const;
+
+  inline uint32_t GetInt32(ConfigEnum key) { return boost::any_cast<uint32_t>(Get(key)); }
+
+  inline uint64_t GetInt64(ConfigEnum key) { return boost::any_cast<uint64_t>(Get(key)); }
+
+  inline double GetDouble(ConfigEnum key) { return boost::any_cast<double>(Get(key)); }
+
+  inline bool GetBool(ConfigEnum key) { return boost::any_cast<bool>(Get(key)); }
+
+  inline std::string GetString(ConfigEnum key) {
+    return boost::any_cast<std::string>(Get(key));
+  }
+
+  virtual ~Config() = default;
+
+ protected:
+  mutable std::unordered_map<ConfigEnum, boost::any> config_map_;
+};
+
+}  // namespace streaming
+}  // namespace ray
diff --git a/streaming/src/util/streaming_logging.cc b/streaming/src/util/streaming_logging.cc
new file mode 100644
index 00000000..fda4824c
--- /dev/null
+++ b/streaming/src/util/streaming_logging.cc
@@ -0,0 +1,5 @@
+#include "util/streaming_logging.h"
+
+namespace ray {
+namespace streaming {}  // namespace streaming
+}  // namespace ray
diff --git a/streaming/src/util/streaming_logging.h b/streaming/src/util/streaming_logging.h
new file mode 100644
index 00000000..11525d37
--- /dev/null
+++ b/streaming/src/util/streaming_logging.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "ray/util/logging.h"
+
+#define STREAMING_LOG RAY_LOG
+#define STREAMING_CHECK RAY_CHECK
+namespace ray {
+namespace streaming {}  // namespace streaming
+}  // namespace ray
diff --git a/streaming/src/util/streaming_util.cc b/streaming/src/util/streaming_util.cc
new file mode 100644
index 00000000..4f2a1353
--- /dev/null
+++ b/streaming/src/util/streaming_util.cc
@@ -0,0 +1,27 @@
+#include "util/streaming_util.h"
+
+#include <cstdlib>
+namespace ray {
+namespace streaming {
+std::string Util::Byte2hex(const uint8_t *data, uint32_t data_size) {
+  constexpr char hex[] = "0123456789abcdef";
+  std::string result;
+  for (uint32_t i = 0; i < data_size; i++) {
+    unsigned short val = data[i];
+    result.push_back(hex[val >> 4]);
+    result.push_back(hex[val & 0xf]);
+  }
+  return result;
+}
+
+std::string Util::Hexqid2str(const std::string &q_id_hex) {
+  std::string result;
+  for (uint32_t i = 0; i < q_id_hex.size(); i += 2) {
+    std::string byte = q_id_hex.substr(i, 2);
+    char chr = static_cast<char>(std::strtol(byte.c_str(), nullptr, 16));
+    result.push_back(chr);
+  }
+  return result;
+}
+}  // namespace streaming
+}  // namespace ray
diff --git a/streaming/src/util/streaming_util.h b/streaming/src/util/streaming_util.h
new file mode 100644
index 00000000..bd703535
--- /dev/null
+++ b/streaming/src/util/streaming_util.h
@@ -0,0 +1,84 @@
+#pragma once
+
+#include <atomic>
+#include <functional>
+#include <sstream>
+#include <string>
+
+#include "ray/common/id.h"
+#include "util/streaming_logging.h"
+
+namespace ray {
+namespace streaming {
+
+class Util {
+ public:
+  static std::string Byte2hex(const uint8_t *data, uint32_t data_size);
+
+  static std::string Hexqid2str(const std::string &q_id_hex);
+
+  template <class T>
+  static std::string join(const T &v, const std::string &delimiter,
+                          const std::string &prefix = "",
+                          const std::string &suffix = "") {
+    std::stringstream ss;
+    size_t i = 0;
+    ss << prefix;
+    for (const auto &elem : v) {
+      if (i != 0) {
+        ss << delimiter;
+      }
+      ss << elem;
+      i++;
+    }
+    ss << suffix;
+    return ss.str();
+  }
+
+  template <class InputIterator>
+  static std::string join(InputIterator first, InputIterator last,
+                          const std::string &delim, const std::string &arround = "") {
+    std::string a = arround;
+    while (first != last) {
+      a += std::to_string(*first);
+      first++;
+      if (first != last) a += delim;
+    }
+    a += arround;
+    return a;
+  }
+
+  template <class InputIterator>
+  static std::string join(InputIterator first, InputIterator last,
+                          std::function<std::string(InputIterator)> func,
+                          const std::string &delim, const std::string &arround = "") {
+    std::string a = arround;
+    while (first != last) {
+      a += func(first);
+      first++;
+      if (first != last) a += delim;
+    }
+    a += arround;
+    return a;
+  }
+};
+
+class AutoSpinLock {
+ public:
+  explicit AutoSpinLock(std::atomic_flag &lock) : lock_(lock) {
+    while (lock_.test_and_set(std::memory_order_acquire))
+      ;
+  }
+  ~AutoSpinLock() { unlock(); }
+  void unlock() { lock_.clear(std::memory_order_release); }
+
+ private:
+  std::atomic_flag &lock_;
+};
+
+inline void ConvertToValidQueueId(const ObjectID &queue_id) {
+  auto addr = const_cast<ObjectID *>(&queue_id);
+  *(reinterpret_cast<uint64_t *>(addr)) = 0;
+}
+}  // namespace streaming
+}  // namespace ray